From 44f59b4301c51f071daa2e951301bb17c14acc9b Mon Sep 17 00:00:00 2001 From: Borislav Stanimirov Date: Fri, 27 Sep 2024 10:42:06 +0300 Subject: [PATCH 01/33] cmake : add option for common library (#9661) --- CMakeLists.txt | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 973907819..415743c2a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -62,6 +62,9 @@ option(LLAMA_SANITIZE_THREAD "llama: enable thread sanitizer" OFF) option(LLAMA_SANITIZE_ADDRESS "llama: enable address sanitizer" OFF) option(LLAMA_SANITIZE_UNDEFINED "llama: enable undefined sanitizer" OFF) +# utils +option(LLAMA_BUILD_COMMON "llama: build common utils library" ON) + # extra artifacts option(LLAMA_BUILD_TESTS "llama: build tests" ${LLAMA_STANDALONE}) option(LLAMA_BUILD_EXAMPLES "llama: build examples" ${LLAMA_STANDALONE}) @@ -191,15 +194,17 @@ install(FILES "${CMAKE_CURRENT_BINARY_DIR}/llama.pc" DESTINATION lib/pkgconfig) # -# programs, examples and tests +# utils, programs, examples and tests # -add_subdirectory(common) +if (LLAMA_BUILD_COMMON) + add_subdirectory(common) +endif() if (LLAMA_BUILD_TESTS AND NOT CMAKE_JS_VERSION) include(CTest) add_subdirectory(tests) -endif () +endif() if (LLAMA_BUILD_EXAMPLES) add_subdirectory(examples) From b5de3b74a595cbfefab7eeb5a567425c6a9690cf Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 27 Sep 2024 20:57:51 +0300 Subject: [PATCH 02/33] readme : update hot topics --- README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index ce954f713..93225b63c 100644 --- a/README.md +++ b/README.md @@ -17,7 +17,8 @@ Inference of Meta's [LLaMA](https://arxiv.org/abs/2302.13971) model (and others) ## Hot topics -- Huggingface GGUF editor: [discussion](https://github.com/ggerganov/llama.cpp/discussions/9268) | [tool](https://huggingface.co/spaces/CISCai/gguf-editor) +- **Hugging Face Inference Endpoints now support GGUF out of the box! https://github.com/ggerganov/llama.cpp/discussions/9669** +- Hugging Face GGUF editor: [discussion](https://github.com/ggerganov/llama.cpp/discussions/9268) | [tool](https://huggingface.co/spaces/CISCai/gguf-editor) ---- From 89f9944981010d195e411a9fbfbb19959412f710 Mon Sep 17 00:00:00 2001 From: Markus Tavenrath Date: Sat, 28 Sep 2024 12:05:05 +0200 Subject: [PATCH 03/33] Enable use to the rebar feature to upload buffers to the device. (#9251) --- ggml/src/ggml-vulkan.cpp | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-vulkan.cpp b/ggml/src/ggml-vulkan.cpp index f9da45881..a877145e8 100644 --- a/ggml/src/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan.cpp @@ -1079,7 +1079,8 @@ static vk_buffer ggml_vk_create_buffer_device(vk_device& device, size_t size) { // Fall back to host memory type buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eDeviceLocal, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent); } else { - buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eDeviceLocal); + // use rebar if available, otherwise fallback to device only visible memory + buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eDeviceLocal | vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent, vk::MemoryPropertyFlagBits::eDeviceLocal); } } catch (const vk::SystemError& e) { std::cerr << "ggml_vulkan: Device memory allocation of size " << size << " failed." << std::endl; @@ -2806,7 +2807,11 @@ static void ggml_vk_buffer_read_async(vk_context subctx, vk_buffer& src, size_t static void ggml_vk_buffer_read(vk_buffer& src, size_t offset, void * dst, size_t size) { VK_LOG_DEBUG("ggml_vk_buffer_read(" << src->buffer << ", " << offset << ", " << size << ")"); - if(src->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) { + + // If the device is not an UMA device the memory is host-accessible through rebar. While writing + // through PCIe is sufficient fast reading back data from PCIe is slower than going through + // the HW device to host copy path. + if(src->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible && src->device->uma) { GGML_ASSERT(src->memory_property_flags & vk::MemoryPropertyFlagBits::eHostCoherent); memcpy(dst, (uint8_t *) src->ptr + offset, size); From 6a0f7794847244fb3b99a983e03137d7e832b585 Mon Sep 17 00:00:00 2001 From: Dan Johansson <164997844+eddnjjn@users.noreply.github.com> Date: Sat, 28 Sep 2024 14:06:16 +0200 Subject: [PATCH 04/33] ggml : add run-time detection of neon, i8mm and sve (#9331) * ggml: Added run-time detection of neon, i8mm and sve Adds run-time detection of the Arm instructions set features neon, i8mm and sve for Linux and Apple build targets. * ggml: Extend feature detection to include non aarch64 Arm arch * ggml: Move definition of ggml_arm_arch_features to the global data section --- ggml/include/ggml.h | 3 ++ ggml/src/ggml-aarch64.c | 13 +----- ggml/src/ggml-quants.c | 4 +- ggml/src/ggml-quants.h | 4 -- ggml/src/ggml.c | 101 ++++++++++++++++++++++++++++++++++------ 5 files changed, 93 insertions(+), 32 deletions(-) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index e24b8a319..9f96e0c48 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -2507,6 +2507,9 @@ extern "C" { GGML_API int ggml_cpu_has_cann (void); GGML_API int ggml_cpu_has_llamafile (void); + // get the sve vector length in bytes + GGML_API int ggml_cpu_get_sve_cnt(void); + // // Internal types and functions exposed for tests and benchmarks // diff --git a/ggml/src/ggml-aarch64.c b/ggml/src/ggml-aarch64.c index 8912de63d..b27f41147 100644 --- a/ggml/src/ggml-aarch64.c +++ b/ggml/src/ggml-aarch64.c @@ -598,15 +598,6 @@ size_t quantize_q4_0_8x8(const float * restrict src, void * restrict dst, int64_ return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, 8, 8); } -// Return the number of byte lanes in the SVE vector if SVE is supported; otherwise, returns 0 if SVE is not supported. -static int sve_lane_count(void) { -#if defined(__ARM_FEATURE_SVE) - return ggml_sve_cnt_b; -#else - return 0; -#endif -} - void ggml_gemv_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; @@ -843,7 +834,7 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void * #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) #if defined(__ARM_FEATURE_SVE) - if (ggml_cpu_has_sve() && sve_lane_count() == QK8_0) { + if (ggml_cpu_has_sve() && ggml_cpu_get_sve_cnt() == QK8_0) { const void * b_ptr = vx; const void * a_ptr = vy; float * res_ptr = s; @@ -2020,7 +2011,7 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void * #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) #if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8) - if (ggml_cpu_has_sve() && ggml_cpu_has_matmul_int8() && sve_lane_count() == QK8_0) { + if (ggml_cpu_has_sve() && ggml_cpu_has_matmul_int8() && ggml_cpu_get_sve_cnt() == QK8_0) { const void * b_ptr = vx; const void * a_ptr = vy; float * res_ptr = s; diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index 8bffce860..7aa6dce89 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -4013,7 +4013,7 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r svfloat32_t sumv0 = svdup_n_f32(0.0f); svfloat32_t sumv1 = svdup_n_f32(0.0f); - const int vector_length = ggml_sve_cnt_b*8; + const int vector_length = ggml_cpu_get_sve_cnt()*8; // VLA Implementation using switch case switch (vector_length) { @@ -5597,7 +5597,7 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r svfloat32_t sumv0 = svdup_n_f32(0.0f); svfloat32_t sumv1 = svdup_n_f32(0.0f); - const int vector_length = ggml_sve_cnt_b*8; + const int vector_length = ggml_cpu_get_sve_cnt()*8; //VLA Implemenation for SVE switch (vector_length) { diff --git a/ggml/src/ggml-quants.h b/ggml/src/ggml-quants.h index e96ce2b5e..df9c4b24a 100644 --- a/ggml/src/ggml-quants.h +++ b/ggml/src/ggml-quants.h @@ -142,10 +142,6 @@ void iq2xs_free_impl(enum ggml_type type); void iq3xs_init_impl(int grid_size); void iq3xs_free_impl(int grid_size); -#if defined(__ARM_FEATURE_SVE) -extern int ggml_sve_cnt_b; -#endif - #ifdef __cplusplus } #endif diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 4b782b0c1..fac4466e3 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -39,9 +39,6 @@ #include #endif -#if defined(__ARM_FEATURE_SVE) -int ggml_sve_cnt_b = 0; -#endif #if defined(__ARM_FEATURE_SVE) || defined(__ARM_FEATURE_MATMUL_INT8) #undef GGML_USE_LLAMAFILE #endif @@ -455,6 +452,15 @@ static ggml_fp16_t ggml_table_gelu_quick_f16[1 << 16]; // precomputed f32 table for f16 (256 KB) (ggml-impl.h) float ggml_table_f32_f16[1 << 16]; +#if defined(__ARM_ARCH) +struct ggml_arm_arch_features_type { + int has_neon; + int has_i8mm; + int has_sve; + int sve_cnt; +} ggml_arm_arch_features = {-1, -1, -1, 0}; +#endif + GGML_CALL const char * ggml_status_to_string(enum ggml_status status) { switch (status) { case GGML_STATUS_ALLOC_FAILED: return "GGML status: error (failed to allocate memory)"; @@ -3673,6 +3679,66 @@ static inline int ggml_up(int n, int m) { //////////////////////////////////////////////////////////////////////////////// +#if defined(__ARM_ARCH) + +#if defined(__linux__) && defined(__aarch64__) +#include +#elif defined(__APPLE__) +#include +#endif + +static void ggml_init_arm_arch_features(void) { +#if defined(__linux__) && defined(__aarch64__) + uint32_t hwcap = getauxval(AT_HWCAP); + uint32_t hwcap2 = getauxval(AT_HWCAP2); + + ggml_arm_arch_features.has_neon = !!(hwcap & HWCAP_ASIMD); + ggml_arm_arch_features.has_i8mm = !!(hwcap2 & HWCAP2_I8MM); + ggml_arm_arch_features.has_sve = !!(hwcap & HWCAP_SVE); + +#if defined(__ARM_FEATURE_SVE) + ggml_arm_arch_features.sve_cnt = PR_SVE_VL_LEN_MASK & prctl(PR_SVE_GET_VL); +#endif +#elif defined(__APPLE__) + int oldp = 0; + size_t size = sizeof(oldp); + if (sysctlbyname("hw.optional.AdvSIMD", &oldp, &size, NULL, 0) != 0) { + oldp = 0; + } + ggml_arm_arch_features.has_neon = oldp; + + if (sysctlbyname("hw.optional.arm.FEAT_I8MM", &oldp, &size, NULL, 0) != 0) { + oldp = 0; + } + ggml_arm_arch_features.has_i8mm = oldp; + + ggml_arm_arch_features.has_sve = 0; + ggml_arm_arch_features.sve_cnt = 0; +#else +// Run-time CPU feature detection not implemented for this platform, fallback to compile time +#if defined(__ARM_NEON) + ggml_arm_arch_features.has_neon = 1; +#else + ggml_arm_arch_features.has_neon = 0; +#endif + +#if defined(__ARM_FEATURE_MATMUL_INT8) + ggml_arm_arch_features.has_i8mm = 1; +#else + ggml_arm_arch_features.has_i8mm = 0; +#endif + +#if defined(__ARM_FEATURE_SVE) + ggml_arm_arch_features.has_sve = 1; + ggml_arm_arch_features.sve_cnt = 16; +#else + ggml_arm_arch_features.has_sve = 0; + ggml_arm_arch_features.sve_cnt = 0; +#endif +#endif +} +#endif + struct ggml_context * ggml_init(struct ggml_init_params params) { // make this function thread safe ggml_critical_section_start(); @@ -3723,6 +3789,10 @@ struct ggml_context * ggml_init(struct ggml_init_params params) { GGML_PRINT_DEBUG("%s: g_state initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f); } +#if defined(__ARM_ARCH) + ggml_init_arm_arch_features(); +#endif + is_first_call = false; } @@ -3771,12 +3841,6 @@ struct ggml_context * ggml_init(struct ggml_init_params params) { GGML_ASSERT_ALIGNED(ctx->mem_buffer); -#if defined(__ARM_FEATURE_SVE) - if (!ggml_sve_cnt_b) { - ggml_sve_cnt_b = PR_SVE_VL_LEN_MASK & prctl(PR_SVE_GET_VL); - } -#endif - GGML_PRINT_DEBUG("%s: context initialized\n", __func__); ggml_critical_section_end(); @@ -23578,16 +23642,16 @@ int ggml_cpu_has_fma(void) { } int ggml_cpu_has_neon(void) { -#if defined(__ARM_NEON) - return 1; +#if defined(__ARM_ARCH) + return ggml_arm_arch_features.has_neon; #else return 0; #endif } int ggml_cpu_has_sve(void) { -#if defined(__ARM_FEATURE_SVE) - return 1; +#if defined(__ARM_ARCH) + return ggml_arm_arch_features.has_sve; #else return 0; #endif @@ -23734,11 +23798,18 @@ int ggml_cpu_has_vsx(void) { } int ggml_cpu_has_matmul_int8(void) { -#if defined(__ARM_FEATURE_MATMUL_INT8) - return 1; +#if defined(__ARM_ARCH) + return ggml_arm_arch_features.has_i8mm; #else return 0; #endif } +int ggml_cpu_get_sve_cnt(void) { +#if defined(__ARM_ARCH) + return ggml_arm_arch_features.sve_cnt; +#else + return 0; +#endif +} //////////////////////////////////////////////////////////////////////////////// From 43bcdd9703ec19af7d2a519640b5ed6f4aac3d53 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Sat, 28 Sep 2024 15:07:14 +0300 Subject: [PATCH 05/33] readme : add tool (#9655) --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 93225b63c..a452a6d78 100644 --- a/README.md +++ b/README.md @@ -174,6 +174,7 @@ Unless otherwise noted these projects are open-source with permissive licensing: **Tools:** - [akx/ggify](https://github.com/akx/ggify) – download PyTorch models from HuggingFace Hub and convert them to GGML +- [akx/ollama-dl](https://github.com/akx/ollama-dl) – download models from the Ollama library to be used directly with llama.cpp - [crashr/gppm](https://github.com/crashr/gppm) – launch llama.cpp instances utilizing NVIDIA Tesla P40 or P100 GPUs with reduced idle power consumption - [gpustack/gguf-parser](https://github.com/gpustack/gguf-parser-go/tree/main/cmd/gguf-parser) - review/check the GGUF file and estimate the memory usage - [Styled Lines](https://marketplace.unity.com/packages/tools/generative-ai/styled-lines-llama-cpp-model-292902) (proprietary licensed, async wrapper of inference part for game development in Unity3d with prebuild Mobile and Web platform wrappers and a model example) From 9a913110cf471a8287ac06c43cbe307d3cf6df99 Mon Sep 17 00:00:00 2001 From: nopperl <54780682+nopperl@users.noreply.github.com> Date: Sat, 28 Sep 2024 12:08:43 +0000 Subject: [PATCH 06/33] llama : add support for Chameleon (#8543) * convert chameleon hf to gguf * add chameleon tokenizer tests * fix lint * implement chameleon graph * add swin norm param * return qk norm weights and biases to original format * implement swin norm * suppress image token output * rem tabs * add comment to conversion * fix ci * check for k norm separately * adapt to new lora implementation * fix layer input for swin norm * move swin_norm in gguf writer * add comment regarding special token regex in chameleon pre-tokenizer * Update src/llama.cpp Co-authored-by: compilade * fix punctuation regex in chameleon pre-tokenizer (@compilade) Co-authored-by: compilade * fix lint * trigger ci --------- Co-authored-by: compilade --- convert_hf_to_gguf.py | 44 +++++ convert_hf_to_gguf_update.py | 1 + gguf-py/gguf/constants.py | 19 ++ gguf-py/gguf/gguf_writer.py | 3 + gguf-py/gguf/tensor_mapping.py | 4 +- include/llama.h | 1 + models/ggml-vocab-chameleon.gguf.inp | 112 ++++++++++++ models/ggml-vocab-chameleon.gguf.out | 46 +++++ src/llama-vocab.cpp | 14 ++ src/llama.cpp | 263 +++++++++++++++++++++++++++ 10 files changed, 505 insertions(+), 2 deletions(-) create mode 100644 models/ggml-vocab-chameleon.gguf.inp create mode 100644 models/ggml-vocab-chameleon.gguf.out diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 7be609054..2cd5a8c11 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -640,6 +640,9 @@ class Model: if chkhsh == "fcace8b9cac38ce847670c970cd5892031a753a1ef381abd1d9af00f713da085": # ref: https://huggingface.co/microsoft/phi-2 res = "phi-2" + if chkhsh == "60824e3c0d9401f89943cbb2fff727f0e2d4c545ba4df2d6e4f09a6db0f5b450": + # ref: https://huggingface.co/facebook/chameleon-7b + res = "chameleon" if res is None: logger.warning("\n") @@ -4138,6 +4141,47 @@ class GraniteMoeModel(GraniteModel): return super().modify_tensors(data_torch, name, bid) +@Model.register("ChameleonForCausalLM") +class ChameleonModel(Model): + model_arch = gguf.MODEL_ARCH.CHAMELEON + + def set_gguf_parameters(self): + super().set_gguf_parameters() + self.gguf_writer.add_swin_norm(self.hparams.get("swin_norm", False)) + + def set_vocab(self): + self._set_vocab_gpt2() + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # ignore image tokenizer for now + # TODO: remove this once image support is implemented for Chameleon + if name.startswith("model.vqmodel"): + return [] + + n_head = self.hparams["num_attention_heads"] + n_kv_head = self.hparams.get("num_key_value_heads") + hidden_dim = self.hparams.get("hidden_size") + + if name.endswith(("q_proj.weight", "q_proj.bias")): + data_torch = LlamaModel.permute(data_torch, n_head, n_head) + if name.endswith(("k_proj.weight", "k_proj.bias")): + data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head) + if name.endswith(("q_norm.weight", "q_norm.bias")): + data_torch = ChameleonModel._reverse_hf_permute(data_torch, n_head, hidden_dim) + if name.endswith(("k_norm.weight", "k_norm.bias")): + data_torch = ChameleonModel._reverse_hf_permute(data_torch, n_kv_head, hidden_dim) + + return [(self.map_tensor_name(name), data_torch)] + + # see: https://github.com/huggingface/transformers/blob/72fb02c47dbbe1999ae105319f24631cad6e2e00/src/transformers/models/chameleon/convert_chameleon_weights_to_hf.py#L176-L203 + @staticmethod + def _reverse_hf_permute(data_torch, n_heads, hidden_dim): + head_dim = hidden_dim // n_heads + data_torch = data_torch[0].view(2, head_dim // 2).t().reshape(1, -1) + data_torch = data_torch.repeat_interleave(n_heads, 0) + return data_torch + + ###### CONVERSION LOGIC ###### diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py index 021f65abd..4d11059f3 100755 --- a/convert_hf_to_gguf_update.py +++ b/convert_hf_to_gguf_update.py @@ -99,6 +99,7 @@ models = [ {'name': "gpt3-finnish", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/TurkuNLP/gpt3-finnish-small", }, {"name": "exaone", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct", }, {"name": "phi-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/microsoft/phi-2", }, + {"name": "chameleon", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/facebook/chameleon-7b", }, ] diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 560eee916..2fd2e9d2b 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -94,6 +94,7 @@ class Keys: DECODER_START_TOKEN_ID = "{arch}.decoder_start_token_id" ATTN_LOGIT_SOFTCAPPING = "{arch}.attn_logit_softcapping" FINAL_LOGIT_SOFTCAPPING = "{arch}.final_logit_softcapping" + SWIN_NORM = "{arch}.swin_norm" RESCALE_EVERY_N_LAYERS = "{arch}.rescale_every_n_layers" TIME_MIX_EXTRA_DIM = "{arch}.time_mix_extra_dim" TIME_DECAY_EXTRA_DIM = "{arch}.time_decay_extra_dim" @@ -236,6 +237,7 @@ class MODEL_ARCH(IntEnum): EXAONE = auto() GRANITE = auto() GRANITE_MOE = auto() + CHAMELEON = auto() class MODEL_TENSOR(IntEnum): @@ -394,6 +396,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.EXAONE: "exaone", MODEL_ARCH.GRANITE: "granite", MODEL_ARCH.GRANITE_MOE: "granitemoe", + MODEL_ARCH.CHAMELEON: "chameleon", } TENSOR_NAMES: dict[MODEL_TENSOR, str] = { @@ -1260,6 +1263,22 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_DOWN_EXP, MODEL_TENSOR.FFN_UP_EXP, ], + MODEL_ARCH.CHAMELEON: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], # TODO } diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index bd059b45c..5c460ef1b 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -670,6 +670,9 @@ class GGUFWriter: def add_expert_weights_scale(self, value: float) -> None: self.add_float32(Keys.LLM.EXPERT_WEIGHTS_SCALE.format(arch=self.arch), value) + def add_swin_norm(self, value: bool) -> None: + self.add_bool(Keys.LLM.SWIN_NORM.format(arch=self.arch), value) + def add_rescale_every_n_layers(self, count: int) -> None: self.add_uint32(Keys.LLM.RESCALE_EVERY_N_LAYERS.format(arch=self.arch), count) diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 4e850726e..5ef91f11d 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -380,7 +380,7 @@ class TensorNameMap: MODEL_TENSOR.ATTN_Q_NORM: ( "language_model.encoder.layers.{bid}.self_attention.q_layernorm", "model.layers.{bid}.self_attn.q_layernorm", # persimmon - "model.layers.{bid}.self_attn.q_norm", # cohere olmoe + "model.layers.{bid}.self_attn.q_norm", # cohere olmoe chameleon "transformer.blocks.{bid}.attn.q_ln", # sea-lion "encoder.layer.{bid}.attention.self.layer_norm_q", # jina-bert-v2 "transformer.layers.{bid}.attn.q_norm", # openelm @@ -389,7 +389,7 @@ class TensorNameMap: MODEL_TENSOR.ATTN_K_NORM: ( "language_model.encoder.layers.{bid}.self_attention.k_layernorm", "model.layers.{bid}.self_attn.k_layernorm", # persimmon - "model.layers.{bid}.self_attn.k_norm", # cohere olmoe + "model.layers.{bid}.self_attn.k_norm", # cohere olmoe chameleon "transformer.blocks.{bid}.attn.k_ln", # sea-lion "encoder.layer.{bid}.attention.self.layer_norm_k", # jina-bert-v2 "transformer.layers.{bid}.attn.k_norm", # openelm diff --git a/include/llama.h b/include/llama.h index 132937a07..caef0bfff 100644 --- a/include/llama.h +++ b/include/llama.h @@ -102,6 +102,7 @@ extern "C" { LLAMA_VOCAB_PRE_TYPE_BLOOM = 23, LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH = 24, LLAMA_VOCAB_PRE_TYPE_EXAONE = 25, + LLAMA_VOCAB_PRE_TYPE_CHAMELEON = 26, }; enum llama_rope_type { diff --git a/models/ggml-vocab-chameleon.gguf.inp b/models/ggml-vocab-chameleon.gguf.inp new file mode 100644 index 000000000..9baf7d77a --- /dev/null +++ b/models/ggml-vocab-chameleon.gguf.inp @@ -0,0 +1,112 @@ +ied 4 ½ months +__ggml_vocab_test__ +Führer +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + + +__ggml_vocab_test__ + + + +__ggml_vocab_test__ + + + + +__ggml_vocab_test__ + + +__ggml_vocab_test__ +Hello world +__ggml_vocab_test__ + Hello world +__ggml_vocab_test__ +Hello World +__ggml_vocab_test__ + Hello World +__ggml_vocab_test__ + Hello World! +__ggml_vocab_test__ +Hello, world! +__ggml_vocab_test__ + Hello, world! +__ggml_vocab_test__ + this is 🦙.cpp +__ggml_vocab_test__ +w048 7tuijk dsdfhu +__ggml_vocab_test__ +нещо на Български +__ggml_vocab_test__ +កាន់តែពិសេសអាចខលចេញ +__ggml_vocab_test__ +🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token) +__ggml_vocab_test__ +Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello + Hello +__ggml_vocab_test__ + ( +__ggml_vocab_test__ + + = +__ggml_vocab_test__ +' era +__ggml_vocab_test__ +Hello, y'all! How are you 😁 ?我想在apple工作1314151天~ +__ggml_vocab_test__ +!!!!!! +__ggml_vocab_test__ +3 +__ggml_vocab_test__ +33 +__ggml_vocab_test__ +333 +__ggml_vocab_test__ +3333 +__ggml_vocab_test__ +33333 +__ggml_vocab_test__ +333333 +__ggml_vocab_test__ +3333333 +__ggml_vocab_test__ +33333333 +__ggml_vocab_test__ +333333333 +__ggml_vocab_test__ +Cửa Việt +__ggml_vocab_test__ + discards +__ggml_vocab_test__ + + + + + + + + + + + +🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български ''''''```````""""......!!!!!!?????? I've been 'told he's there, 'RE you sure? 'M not sure I'll make it, 'D you like some tea? We'Ve a'lL +__ggml_vocab_test__ diff --git a/models/ggml-vocab-chameleon.gguf.out b/models/ggml-vocab-chameleon.gguf.out new file mode 100644 index 000000000..7c5413fee --- /dev/null +++ b/models/ggml-vocab-chameleon.gguf.out @@ -0,0 +1,46 @@ + 17245 16604 16403 16604 33583 18355 + 16421 51153 + + 16604 + 16650 + 16650 16604 + 16581 + 16582 + 16582 16582 + 16582 16582 16582 + 16581 16582 + 31596 17394 + 34926 17394 + 31596 18671 + 34926 18671 + 34926 18671 16384 + 31596 16395 17394 16384 + 34926 16395 17394 16384 + 16811 16704 20410 16483 16631 16397 52854 + 16470 16399 16403 16407 16604 16406 35764 38185 51595 22592 26639 + 29479 23955 17012 20103 25527 27670 17408 19005 21473 24774 + 54254 42231 48084 29409 16617 61889 29409 16608 21954 16628 21954 16499 58445 29409 16607 58445 21954 16479 42231 21954 16611 21954 16607 21954 16633 21954 16611 29409 16607 21954 16615 + 52351 16604 16391 25825 16392 23686 16498 39161 18885 16618 16488 30853 16604 16391 54124 17153 25134 16656 18476 26169 16895 16392 62193 16611 16604 16391 24664 17153 57169 16721 16872 17073 17304 28729 16392 + 31596 + 34926 + 16650 31596 + 16650 34926 + 16696 31596 + 16696 31596 16582 16696 31596 + 16604 16391 + 16582 16604 16412 + 16390 22623 + 31596 16395 16712 16390 16828 16384 17674 16769 16732 23686 16607 16604 16414 24427 16623 41809 16495 28999 36469 45292 30197 16400 16402 16400 16403 16400 16404 16400 43969 65211 16636 + 16384 16384 16384 16384 16384 16384 + 16402 + 16402 16402 + 16402 16402 16402 + 16402 16402 16402 16402 + 16402 16402 16402 16402 16402 + 16402 16402 16402 16402 16402 16402 + 16402 16402 16402 16402 16402 16402 16402 + 16402 16402 16402 16402 16402 16402 16402 16402 + 16402 16402 16402 16402 16402 16402 16402 16402 16402 + 16418 19038 16639 16448 24315 33727 16467 + 18765 17981 + 16582 16604 16582 16582 16604 16582 16582 16582 16604 16581 16604 16581 16581 16604 16581 16582 16650 16582 16650 16604 16582 16696 16582 16696 16604 16582 52351 16604 16391 25825 16392 23686 16498 39161 18885 16618 16488 30853 16604 16391 54124 17153 25134 16656 18476 26169 16895 16392 62193 16611 20410 16483 16631 18885 16483 16631 16604 16402 16604 16402 16402 16604 16402 16402 16402 16604 16402 16402 16402 16402 16604 16402 16402 16402 16402 16402 16604 16402 16402 16402 16402 16402 16402 16604 16402 16402 16402 16402 16402 16402 16402 16604 16402 16402 16402 16402 16402 16402 16402 16402 16604 16402 16397 16402 16604 16402 16397 16397 16402 16604 16402 16397 16397 16397 16402 16604 54254 42231 48084 29409 16617 61889 29409 16608 21954 16628 21954 16499 58445 29409 16607 58445 21954 16479 42231 21954 16611 27683 16607 16604 16414 24427 16623 41809 16495 28999 36469 45292 30197 16400 16402 16400 16403 16400 16404 16400 43969 65211 16636 16604 16396 16396 16396 16396 16396 16396 16412 16412 16412 16412 16412 16412 16412 27268 23955 17012 20103 25527 27670 17408 19005 21473 24774 16604 16390 16390 16390 16390 16390 16390 16447 16447 16447 16447 16447 16447 16447 16385 16385 16385 16385 16397 16397 16397 16397 16397 16397 16384 16384 16384 16384 16384 16384 16414 16414 16414 16414 16414 16414 16687 16390 16690 16992 16604 16390 61797 16733 16390 16466 16986 16395 16604 16390 17879 16732 17811 16414 16604 16390 16428 16804 17811 16687 16390 16683 17190 16728 16395 16604 16390 16419 16732 16945 16991 25251 16414 17119 16390 38127 16641 16390 16459 16427 diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index a771eccda..146d416f7 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -450,6 +450,20 @@ struct llm_tokenizer_bpe { "[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))*((?=[\\p{L}])([^A-Z]))+|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))+((?=[\\p{L}])([^A-Z]))*|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", }; break; + case LLAMA_VOCAB_PRE_TYPE_CHAMELEON: + // Note: in theory, the special token (sentinel and image token) regex_exprs below + // are unnecessary, as they are split in `tokenizer_st_partition` anyway. + // However, since the upstream pre-tokenizer uses them, they are also + // included here (see https://huggingface.co/facebook/chameleon-7b). + regex_exprs = { + "", // Sentinel tokens + "(IMGIMG)((A|B|C|D|E|F|G|H|I){1,4})Z", // Image tokens + "([\\t\\n]| | )", // directly from tokenizer.json + "\\p{N}", // Individual digits + "[\\p{P}!-/:-@\\[-`{-~]", // Punctuation, Isolated + "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", + }; + break; default: // default regex for BPE tokenization pre-processing regex_exprs = { diff --git a/src/llama.cpp b/src/llama.cpp index 0accb1492..f450eaf9d 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -216,6 +216,7 @@ enum llm_arch { LLM_ARCH_RWKV6, LLM_ARCH_GRANITE, LLM_ARCH_GRANITE_MOE, + LLM_ARCH_CHAMELEON, LLM_ARCH_UNKNOWN, }; @@ -268,6 +269,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_RWKV6, "rwkv6" }, { LLM_ARCH_GRANITE, "granite" }, { LLM_ARCH_GRANITE_MOE, "granitemoe" }, + { LLM_ARCH_CHAMELEON, "chameleon" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -304,6 +306,7 @@ enum llm_kv { LLM_KV_DECODER_START_TOKEN_ID, LLM_KV_ATTN_LOGIT_SOFTCAPPING, LLM_KV_FINAL_LOGIT_SOFTCAPPING, + LLM_KV_SWIN_NORM, LLM_KV_RESCALE_EVERY_N_LAYERS, LLM_KV_TIME_MIX_EXTRA_DIM, LLM_KV_TIME_DECAY_EXTRA_DIM, @@ -411,6 +414,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_DECODER_START_TOKEN_ID, "%s.decoder_start_token_id" }, { LLM_KV_ATTN_LOGIT_SOFTCAPPING, "%s.attn_logit_softcapping" }, { LLM_KV_FINAL_LOGIT_SOFTCAPPING, "%s.final_logit_softcapping" }, + { LLM_KV_SWIN_NORM, "%s.swin_norm" }, { LLM_KV_RESCALE_EVERY_N_LAYERS, "%s.rescale_every_n_layers" }, { LLM_KV_TIME_MIX_EXTRA_DIM, "%s.time_mix_extra_dim" }, { LLM_KV_TIME_DECAY_EXTRA_DIM, "%s.time_decay_extra_dim" }, @@ -1499,6 +1503,25 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, }, }, + { + LLM_ARCH_CHAMELEON, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + }, + }, { LLM_ARCH_UNKNOWN, { @@ -2362,6 +2385,7 @@ struct llama_hparams { bool vocab_only; bool rope_finetuned; bool use_par_res; + bool swin_norm; uint32_t n_vocab; uint32_t n_ctx_train; // context size the model was trained on @@ -6084,6 +6108,18 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_CHAMELEON: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + hparams.f_norm_eps = 1e-5; // eps for qk-norm, torch default + ml.get_key(LLM_KV_SWIN_NORM, hparams.swin_norm); + + switch (hparams.n_layer) { + case 32: model.type = e_model::MODEL_7B; break; + case 48: model.type = e_model::MODEL_34B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; default: (void)0; } @@ -6341,6 +6377,11 @@ static void llm_load_vocab( } else if ( tokenizer_pre == "exaone") { vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_EXAONE; + } else if ( + tokenizer_pre == "chameleon") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_CHAMELEON; + vocab.tokenizer_add_bos = true; + vocab.tokenizer_clean_spaces = false; } else { throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str())); } @@ -8728,6 +8769,45 @@ static bool llm_load_tensors( } } break; + case LLM_ARCH_CHAMELEON: + { + model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + + // output + { + model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (model.output == NULL) { + model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); + } + } + + for (int i = 0; i < n_layer; ++i) { + ggml_context * ctx_layer = ctx_for_layer(i); + ggml_context * ctx_split = ctx_for_layer_split(i); + + auto & layer = model.layers[i]; + + layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_q_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k, n_head}); + layer.attn_k_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k, n_head_kv}); + layer.attn_q_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {n_embd_head_k, n_head}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.attn_k_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {n_embd_head_k, n_head_kv}, llama_model_loader::TENSOR_NOT_REQUIRED); + + layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); + layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); + layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); + layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + + layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + + layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); + layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); + layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + } + } break; default: throw std::runtime_error("unknown architecture"); } @@ -15872,6 +15952,184 @@ struct llm_build_context { return gf; } + + // ref: https://github.com/facebookresearch/chameleon + // based on the original build_llama() function, changes: + // * qk-norm + // * swin-norm + // * removed bias + // * removed MoE + struct ggml_cgraph * build_chameleon() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + + // mutable variable, needed during the last layer of the computation to skip unused tokens + int32_t n_tokens = this->n_tokens; + + const int64_t n_embd_head = hparams.n_embd_head_v; + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + + // inp_pos - contains the positions + struct ggml_tensor * inp_pos = build_inp_pos(); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * inpSA = inpL; + + // norm + if (hparams.swin_norm) { + cur = inpL; + } else { + cur = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "attn_norm", il); + } + + // self-attention + { + // compute Q and K and RoPE them + struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + if (model.layers[il].attn_q_norm) { + Qcur = ggml_view_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens, + ggml_element_size(Qcur) * n_embd_head, + ggml_element_size(Qcur) * n_embd_head * n_head, + 0); + cb(Qcur, "Qcur", il); + + Qcur = llm_build_norm(ctx0, Qcur, hparams, + model.layers[il].attn_q_norm, + model.layers[il].attn_q_norm_b, + LLM_NORM, cb, il); + cb(Qcur, "Qcur", il); + } + + if (model.layers[il].attn_k_norm) { + Kcur = ggml_view_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens, + ggml_element_size(Kcur) * n_embd_head, + ggml_element_size(Kcur) * n_embd_head * n_head_kv, + 0); + cb(Kcur, "Kcur", il); + + Kcur = llm_build_norm(ctx0, Kcur, hparams, + model.layers[il].attn_k_norm, + model.layers[il].attn_k_norm_b, + LLM_NORM, cb, il); + cb(Kcur, "Kcur", il); + } + + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Qcur, "Qcur", il); + + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Kcur, "Kcur", il); + + cur = llm_build_kv(ctx0, lctx, kv_self, gf, + model.layers[il].wo, nullptr, + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + + if (hparams.swin_norm) { + cur = llm_build_norm(ctx0, cur, hparams, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, cb, il); + } + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + n_tokens = n_outputs; + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + if (!hparams.swin_norm) { + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + } + + cur = llm_build_ffn(ctx0, lctx, cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, cb, il); + cb(cur, "ffn_out", il); + + if (hparams.swin_norm) { + cur = llm_build_norm(ctx0, cur, hparams, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + cur = lctx.cvec.apply_to(ctx0, cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = llm_build_norm(ctx0, cur, hparams, + model.output_norm, NULL, + LLM_NORM_RMS, cb, -1); + cb(cur, "result_norm", -1); + + // lm_head + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); + cb(cur, "result_output_with_img_logits", -1); + + // TODO: this suppresses the output of image tokens, which is required to enable text-only outputs. + // Needs to be removed once image outputs are supported. + int img_token_end_idx = 8196; + int img_token_start_idx = 4; + int num_img_tokens = img_token_end_idx - img_token_start_idx; + // creates 1d tensor of size num_img_tokens and values -FLT_MAX, + // which ensures that text token values are always at least larger than image token values + struct ggml_tensor * img_logits = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, num_img_tokens); + img_logits = ggml_clamp(ctx0, img_logits, -FLT_MAX, -FLT_MAX); + cb(img_logits, "img_logits", -1); + cur = ggml_set_1d(ctx0, cur, img_logits, ggml_element_size(cur) * img_token_start_idx); + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } }; static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector & ids) { @@ -16132,6 +16390,10 @@ static struct ggml_cgraph * llama_build_graph( { result = llm.build_rwkv6(); } break; + case LLM_ARCH_CHAMELEON: + { + result = llm.build_chameleon(); + } break; default: GGML_ABORT("fatal error"); } @@ -19257,6 +19519,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { case LLM_ARCH_CHATGLM: case LLM_ARCH_GRANITE: case LLM_ARCH_GRANITE_MOE: + case LLM_ARCH_CHAMELEON: return LLAMA_ROPE_TYPE_NORM; // the pairs of head values are offset by n_rot/2 From 6102037bbb55521880ae78a6ee6c2a0c00c901df Mon Sep 17 00:00:00 2001 From: Zhenwei Jin <109658203+kylo5aby@users.noreply.github.com> Date: Sat, 28 Sep 2024 20:10:58 +0800 Subject: [PATCH 07/33] vocab : refactor tokenizer to reduce init overhead (#9449) * refactor tokenizer * llama : make llm_tokenizer more private ggml-ci * refactor tokenizer * refactor tokenizer * llama : make llm_tokenizer more private ggml-ci * remove unused files * remove unused fileds to avoid unused filed build error * avoid symbol link error * Update src/llama.cpp * Update src/llama.cpp --------- Co-authored-by: Georgi Gerganov --- .../convert-llama2c-to-ggml.cpp | 14 +- src/llama-vocab.cpp | 266 +++++++++++------- src/llama-vocab.h | 9 + src/llama.cpp | 2 + tests/test-tokenizer-0.cpp | 88 +++--- 5 files changed, 238 insertions(+), 141 deletions(-) diff --git a/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp b/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp index ecff95f9a..c140daed3 100644 --- a/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +++ b/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp @@ -201,7 +201,7 @@ static void print_sample_weights(TransformerWeights *w){ //////////////////////////////////////// ggml structs and functions required to load models, configs and save the model. -struct llama_vocab { +struct my_llama_vocab { using id = int32_t; using token = std::string; using ttype = llama_token_type; @@ -525,7 +525,7 @@ static std::string llama_escape_whitespaces(const std::string & text) { return out.str(); } -static void load_vocab(const char * filename, const Config * config, struct llama_vocab * vocab) { +static void load_vocab(const char * filename, const Config * config, struct my_llama_vocab * vocab) { if (is_ggml_file(filename)) { LOG_INF("%s: Loading vocabulary from gguf file %s\n", __func__, filename); struct ggml_context * ctx_data = NULL; @@ -583,13 +583,13 @@ static void load_vocab(const char * filename, const Config * config, struct llam const int n_vocab = config->vocab_size; /* uint32_t max_token_length = */ file.read_u32(); // unused vocab->id_to_token.resize(n_vocab); - for (llama_vocab::id id=0; idtoken_embedding_table -> model->tok_embeddings @@ -671,7 +671,7 @@ static void save_as_llama_model( std::vector tokens; std::vector scores; std::vector token_types; - for (const llama_vocab::token_data & token_data : vocab->id_to_token) { + for (const my_llama_vocab::token_data & token_data : vocab->id_to_token) { tokens.push_back(token_data.text.c_str()); scores.push_back(token_data.score); token_types.push_back(token_data.type); @@ -905,7 +905,7 @@ int main(int argc, char ** argv) { fclose(file); } - struct llama_vocab vocab; + struct my_llama_vocab vocab; load_vocab(params.fn_vocab_model, &config, &vocab); struct my_llama_model model; diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index 146d416f7..e4d844a73 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -50,7 +50,7 @@ struct naive_trie { res.first->second.insert(key + 1, len - 1, value); } } - std::pair get_longest_prefix(const char * key, size_t len, size_t offset = 0) { + std::pair get_longest_prefix(const char * key, size_t len, size_t offset = 0) const { if (len == 0 || offset == len) { return std::make_pair(key, offset); } @@ -79,6 +79,15 @@ struct naive_trie { // impl // +struct llm_tokenizer { + llm_tokenizer() {} + virtual ~llm_tokenizer() = default; +}; + +llama_vocab::~llama_vocab() { + delete tokenizer; +} + int llama_vocab::find_bpe_rank(const std::string & token_left, const std::string & token_right) const { GGML_ASSERT(token_left.find(' ') == std::string::npos); GGML_ASSERT(token_left.find('\n') == std::string::npos); @@ -187,10 +196,15 @@ struct llm_bigram_spm { size_t size; }; -struct llm_tokenizer_spm { - llm_tokenizer_spm(const llama_vocab & vocab) : vocab(vocab) {} +struct llm_tokenizer_spm : llm_tokenizer { + llm_tokenizer_spm(const llama_vocab & /*vocab*/) : llm_tokenizer() {} +}; + +struct llm_tokenizer_spm_session { + llm_tokenizer_spm_session(const llama_vocab & vocab) : vocab(vocab) {} void tokenize(const std::string & text, std::vector & output) { + // split string into utf8 chars int index = 0; size_t offs = 0; @@ -271,7 +285,7 @@ private: return; } - resegment(symbols[p->second.first], output); + resegment(symbols[p->second.first], output); resegment(symbols[p->second.second], output); } @@ -279,7 +293,6 @@ private: if (left == -1 || right == -1) { return; } - const std::string text = std::string(symbols[left].text, symbols[left].n + symbols[right].n); auto token = vocab.token_to_id.find(text); @@ -306,10 +319,11 @@ private: } const llama_vocab & vocab; + // currently unused + // const llm_tokenizer_spm * spm_tokenizer; std::vector symbols; llm_bigram_spm::queue work_queue; - std::map> rev_merge; }; @@ -352,8 +366,8 @@ struct llm_bigram_bpe { size_t size; }; -struct llm_tokenizer_bpe { - llm_tokenizer_bpe(const llama_vocab & vocab): vocab(vocab) { +struct llm_tokenizer_bpe : llm_tokenizer { + llm_tokenizer_bpe(const llama_vocab & vocab) : llm_tokenizer() { GGML_ASSERT(vocab.type == LLAMA_VOCAB_TYPE_BPE); switch (vocab.type_pre) { case LLAMA_VOCAB_PRE_TYPE_LLAMA3: @@ -476,7 +490,14 @@ struct llm_tokenizer_bpe { } } - void append(const llama_vocab::id token_id, std::vector & output) const { + std::vector regex_exprs; +}; + +struct llm_tokenizer_bpe_session { + llm_tokenizer_bpe_session(const llama_vocab & vocab) : vocab(vocab), + bpe_tokenizer(static_cast(vocab.tokenizer)) {} + + static void append(const llama_vocab::id token_id, std::vector & output) { output.push_back(token_id); } @@ -515,12 +536,11 @@ struct llm_tokenizer_bpe { void tokenize(const std::string & text, std::vector & output) { int final_prev_index = -1; - - const auto word_collection = unicode_regex_split(text, regex_exprs); + const auto word_collection = unicode_regex_split(text, bpe_tokenizer->regex_exprs); symbols_final.clear(); - for (auto & word : word_collection) { + for (const auto & word : word_collection) { work_queue = llm_bigram_bpe::queue(); symbols.clear(); @@ -623,7 +643,6 @@ private: if (left == -1 || right == -1) { return; } - std::string left_token = std::string(symbols[left].text, symbols[left].n); std::string right_token = std::string(symbols[right].text, symbols[right].n); @@ -647,12 +666,10 @@ private: } const llama_vocab & vocab; - - std::vector regex_exprs; + const llm_tokenizer_bpe * bpe_tokenizer; std::vector symbols; std::vector symbols_final; - llm_bigram_bpe::queue work_queue; }; @@ -660,15 +677,17 @@ private: // WPM tokenizer // -struct llm_tokenizer_wpm { - llm_tokenizer_wpm(const llama_vocab & vocab): vocab(vocab) {} +struct llm_tokenizer_wpm : llm_tokenizer { + llm_tokenizer_wpm(const llama_vocab & /*vocab*/) : llm_tokenizer() {} +}; - void tokenize(const std::string & text, std::vector & output) const { +struct llm_tokenizer_wpm_session { + llm_tokenizer_wpm_session(const llama_vocab & vocab) : vocab(vocab) {} + + void tokenize(const std::string & text, std::vector & output) { const auto & token_map = vocab.token_to_id; - // normalize and split by whitespace std::vector words = preprocess(text); - // bos token prepended already // find the longest tokens that form the words @@ -713,7 +732,7 @@ struct llm_tokenizer_wpm { } // TODO: reduce string copies by using cpts_offs array - std::vector preprocess(const std::string & text) const { + static std::vector preprocess(const std::string & text) { const std::vector cpts_nfd = unicode_cpts_normalize_nfd(unicode_cpts_from_utf8(text)); std::vector words(1, ""); @@ -765,15 +784,18 @@ struct llm_tokenizer_wpm { //(cpt >= 0xFF00 && cpt <= 0xFFEF); } +private: const llama_vocab & vocab; + // currently unused + // const llm_tokenizer_wpm * wpm_tokenizer; }; // // UGM tokenizer // -struct llm_tokenizer_ugm { - llm_tokenizer_ugm(const llama_vocab & vocab) : vocab(vocab) { +struct llm_tokenizer_ugm : llm_tokenizer { + llm_tokenizer_ugm(const llama_vocab & vocab) : llm_tokenizer() { if (vocab.precompiled_charsmap.size() > 0) { size_t charsmap_offset = 0; @@ -819,6 +841,30 @@ struct llm_tokenizer_ugm { unknown_token_score = min_score - unknown_token_score_penalty; } + // escaped space symbol - U+2581 (Lower One Eighth Block) + const std::string escaped_space = "\xE2\x96\x81"; + + const char * prefix_replacements = NULL; + size_t prefix_replacements_size = 0; + + const uint32_t * xcda_array = NULL; + size_t xcda_array_size = 0; + + struct naive_trie user_defined_token_matcher; + + float min_score = FLT_MAX; + float max_score = -FLT_MAX; + + float unknown_token_score_penalty = 10.0; + float unknown_token_score; + + struct naive_trie token_matcher; +}; + +struct llm_tokenizer_ugm_session { + llm_tokenizer_ugm_session(const llama_vocab & vocab) : vocab(vocab), + ugm_tokenizer(static_cast(vocab.tokenizer)) {} + /* This implementation is based on SentencePiece optimized Viterbi algorithm for * unigram language models. The general idea is to: * - move along the input sequence in steps of one UTF code point, @@ -857,7 +903,7 @@ struct llm_tokenizer_ugm { // traverse the token matcher trie to find a matching token bool single_codepoint_token_found = false; const struct best_tokenization & current_best = tokenization_results[input_offset]; - const struct naive_trie * node = token_matcher.traverse(normalized[prefix_offset++]); + const struct naive_trie * node = ugm_tokenizer->token_matcher.traverse(normalized[prefix_offset++]); while (prefix_offset <= input_len && node != NULL) { // check if we found valid token in prefix @@ -887,7 +933,7 @@ struct llm_tokenizer_ugm { // if we didn't find a valid token corresponding to the whole UTF code point // then use unknown token as the tokenization of this UTF code point if (!single_codepoint_token_found) { - const double challenger_score = current_best.score_sum + unknown_token_score; + const double challenger_score = current_best.score_sum + ugm_tokenizer->unknown_token_score; prefix_offset = input_offset + n_utf8_code_units; struct best_tokenization & current_champ = tokenization_results[prefix_offset]; if (challenger_score > current_champ.score_sum) { @@ -919,7 +965,6 @@ struct llm_tokenizer_ugm { } private: - const llama_vocab & vocab; // helper structure for returning normalization results struct normalization_result { @@ -932,7 +977,7 @@ private: normalized->clear(); normalized->reserve(input.size() * 3); - const std::string space = vocab.tokenizer_escape_whitespaces ? escaped_space : " "; + const std::string space = vocab.tokenizer_escape_whitespaces ? ugm_tokenizer->escaped_space : " "; bool shall_prepend_space = !vocab.tokenizer_treat_whitespace_as_suffix && vocab.tokenizer_add_space_prefix; bool shall_append_space = vocab.tokenizer_treat_whitespace_as_suffix && vocab.tokenizer_add_space_prefix; @@ -1014,13 +1059,21 @@ private: size_t xcda_array_size; }; + // this structure stores the best tokenization so far at input_offset + struct best_tokenization { + llama_token token_id; + size_t input_offset; + float score_sum; + }; + struct normalization_result normalize_prefix(const std::string & input, size_t input_offset) { if (input_offset == input.size()) { return { &input[input_offset], 0, 0 }; } // if input prefix matches some user-defined token return this token as normalization result - auto user_defined_token_match = user_defined_token_matcher.get_longest_prefix(&input[input_offset], input.size() - input_offset); + auto user_defined_token_match = + ugm_tokenizer->user_defined_token_matcher.get_longest_prefix(&input[input_offset], input.size() - input_offset); if (user_defined_token_match.second > 0) { return { &input[input_offset], user_defined_token_match.second, user_defined_token_match.second }; } @@ -1028,8 +1081,8 @@ private: size_t longest_prefix_length = 0; size_t longest_prefix_offset = 0; - if (xcda_array_size > 0) { - struct xcda_array_view xcda_view(xcda_array, xcda_array_size); + if (ugm_tokenizer->xcda_array_size > 0) { + struct xcda_array_view xcda_view(ugm_tokenizer->xcda_array, ugm_tokenizer->xcda_array_size); // Find the longest normalized sequence matching the input prefix by walking // the XOR-compressed compact double array (XCDA) starting from the root node @@ -1065,50 +1118,27 @@ private: if (longest_prefix_length > 0) { // we have a match, so return the replacement sequence - if (longest_prefix_offset >= prefix_replacements_size) { + if (longest_prefix_offset >= ugm_tokenizer->prefix_replacements_size) { throw std::runtime_error("Index out of array bounds in precompiled charsmap!"); } - const char * prefix_replacement = &prefix_replacements[longest_prefix_offset]; + const char * prefix_replacement = &(ugm_tokenizer->prefix_replacements)[longest_prefix_offset]; return { prefix_replacement, strlen(prefix_replacement), longest_prefix_length }; - } else { - // check if the input prefix contains a valid sequence of UTF-8 code units - try { - // if yes, return this sequence unmodified - size_t prefix_offset = input_offset; - unicode_cpt_from_utf8(input, prefix_offset); - return { &input[input_offset], prefix_offset - input_offset, prefix_offset - input_offset }; - } catch (std::invalid_argument & /*ex*/) { - // if no, consume 1 byte and return U+FFFD - REPLACEMENT CHARACTER - return { "\xEF\xBF\xBD", 3, 1 }; - } + } + + // check if the input prefix contains a valid sequence of UTF-8 code units + try { + // if yes, return this sequence unmodified + size_t prefix_offset = input_offset; + unicode_cpt_from_utf8(input, prefix_offset); + return { &input[input_offset], prefix_offset - input_offset, prefix_offset - input_offset }; + } catch (std::invalid_argument & /*ex*/) { + // if no, consume 1 byte and return U+FFFD - REPLACEMENT CHARACTER + return { "\xEF\xBF\xBD", 3, 1 }; } } - // escaped space symbol - U+2581 (Lower One Eighth Block) - const std::string escaped_space = "\xE2\x96\x81"; - - const char * prefix_replacements = NULL; - size_t prefix_replacements_size = 0; - - const uint32_t * xcda_array = NULL; - size_t xcda_array_size = 0; - - struct naive_trie user_defined_token_matcher; - - // this structure stores the best tokenization so far at input_offset - struct best_tokenization { - llama_token token_id; - size_t input_offset; - float score_sum; - }; - - float min_score = FLT_MAX; - float max_score = -FLT_MAX; - - float unknown_token_score_penalty = 10.0; - float unknown_token_score; - - struct naive_trie token_matcher; + const llama_vocab & vocab; + const llm_tokenizer_ugm * ugm_tokenizer; }; // @@ -1169,8 +1199,8 @@ static std::vector llama_unescape_rwkv_token(const std::string & escape return output; } -struct llm_tokenizer_rwkv { - llm_tokenizer_rwkv(const llama_vocab & vocab): vocab(vocab) { +struct llm_tokenizer_rwkv : llm_tokenizer { + llm_tokenizer_rwkv(const llama_vocab & vocab) : llm_tokenizer() { // RWKV supports arbitrary byte tokens, but the vocab struct only supports string tokens. // For now, we decode the vocab here into the lookup we'll use for tokenization. @@ -1182,11 +1212,17 @@ struct llm_tokenizer_rwkv { } } + struct naive_trie token_matcher; +}; + +struct llm_tokenizer_rwkv_session { + llm_tokenizer_rwkv_session(const llama_vocab & vocab) : vocab(vocab), + rwkv_tokenizer(static_cast(*vocab.tokenizer)) {} + void tokenize(const std::string & text, std::vector & output) { uint32_t position = 0; - while (position < text.size()) { - const struct naive_trie * node = token_matcher.traverse(text[position]); + const struct naive_trie * node = rwkv_tokenizer.token_matcher.traverse(text[position]); if (node == NULL) { // no matching token found, add unknown token output.push_back(vocab.special_unk_id); @@ -1211,11 +1247,33 @@ struct llm_tokenizer_rwkv { } } +private: const llama_vocab & vocab; - - struct naive_trie token_matcher; + const llm_tokenizer_rwkv & rwkv_tokenizer; }; +void llama_vocab::init_tokenizer() { + switch (type) { + case LLAMA_VOCAB_TYPE_SPM: + tokenizer = new llm_tokenizer_spm(*this); + break; + case LLAMA_VOCAB_TYPE_BPE: + tokenizer = new llm_tokenizer_bpe(*this); + break; + case LLAMA_VOCAB_TYPE_WPM: + tokenizer = new llm_tokenizer_wpm(*this); + break; + case LLAMA_VOCAB_TYPE_UGM: + tokenizer = new llm_tokenizer_ugm(*this); + break; + case LLAMA_VOCAB_TYPE_RWKV: + tokenizer = new llm_tokenizer_rwkv(*this); + break; + default: + GGML_ABORT("unsupported vocab type"); + } +} + // // (de-) tokenize // @@ -1277,7 +1335,7 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list< // if a fragment is text ( not yet processed ) if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { - auto & raw_text = fragment.raw_text; + const auto & raw_text = fragment.raw_text; auto raw_text_base_offset = fragment.offset; auto raw_text_base_length = fragment.length; @@ -1376,7 +1434,13 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list< } } -std::vector llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool add_special, bool parse_special) { +std::vector llama_tokenize_internal( + const llama_vocab & vocab, + std::string raw_text, + bool add_special, + bool parse_special) { + GGML_ASSERT(vocab.tokenizer && "Tokenizer not initialized. Call llama_vocab::init_tokenizer() first."); + std::vector output; std::forward_list fragment_buffer; @@ -1413,9 +1477,9 @@ std::vector llama_tokenize_internal(const llama_vocab & vocab, #ifdef PRETOKENIZERDEBUG LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str()); #endif - llm_tokenizer_spm tokenizer(vocab); llama_escape_whitespace(raw_text); - tokenizer.tokenize(raw_text, output); + llm_tokenizer_spm_session session(vocab); + session.tokenize(raw_text, output); is_prev_special = false; } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN) output.push_back(fragment.token); @@ -1437,10 +1501,11 @@ std::vector llama_tokenize_internal(const llama_vocab & vocab, } break; case LLAMA_VOCAB_TYPE_BPE: { - llm_tokenizer_bpe tokenizer(vocab); - + llm_tokenizer_bpe_session session(vocab); + // it calls some other methods that are not exist in llm_tokenizer, + // here just cast it to bpe tokenizer object if (add_special) { - tokenizer.append_bos(output); + session.append_bos(output); } for (const auto & fragment : fragment_buffer) { if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { @@ -1449,15 +1514,15 @@ std::vector llama_tokenize_internal(const llama_vocab & vocab, #ifdef PRETOKENIZERDEBUG LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str()); #endif - tokenizer.tokenize(raw_text, output); + session.tokenize(raw_text, output); } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN) - tokenizer.append(fragment.token, output); + session.append(fragment.token, output); } } if (add_special) { - tokenizer.append_eos(output); - tokenizer.check_double_bos_eos(output); + session.append_eos(output); + session.check_double_bos_eos(output); } } break; case LLAMA_VOCAB_TYPE_WPM: @@ -1467,7 +1532,7 @@ std::vector llama_tokenize_internal(const llama_vocab & vocab, output.push_back(vocab.special_cls_id); } - llm_tokenizer_wpm tokenizer(vocab); + llm_tokenizer_wpm_session session(vocab); for (const auto & fragment : fragment_buffer) { if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { @@ -1476,7 +1541,7 @@ std::vector llama_tokenize_internal(const llama_vocab & vocab, #ifdef PRETOKENIZERDEBUG LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str()); #endif - tokenizer.tokenize(raw_text, output); + session.tokenize(raw_text, output); } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN) output.push_back(fragment.token); } @@ -1489,12 +1554,11 @@ std::vector llama_tokenize_internal(const llama_vocab & vocab, } break; case LLAMA_VOCAB_TYPE_UGM: { - llm_tokenizer_ugm tokenizer(vocab); - if (add_special && vocab.tokenizer_add_bos != 0) { GGML_ASSERT(vocab.special_bos_id != -1); output.push_back(vocab.special_bos_id); } + llm_tokenizer_ugm_session session(vocab); for (const auto & fragment : fragment_buffer) { if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { @@ -1502,7 +1566,7 @@ std::vector llama_tokenize_internal(const llama_vocab & vocab, #ifdef PRETOKENIZERDEBUG LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str()); #endif - tokenizer.tokenize(raw_text, output); + session.tokenize(raw_text, output); } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN) output.push_back(fragment.token); } @@ -1522,6 +1586,7 @@ std::vector llama_tokenize_internal(const llama_vocab & vocab, } break; case LLAMA_VOCAB_TYPE_RWKV: { + llm_tokenizer_rwkv_session session(vocab); for (const auto & fragment : fragment_buffer) { if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length); @@ -1530,8 +1595,7 @@ std::vector llama_tokenize_internal(const llama_vocab & vocab, LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str()); #endif - llm_tokenizer_rwkv tokenizer(vocab); - tokenizer.tokenize(raw_text, output); + session.tokenize(raw_text, output); } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN) output.push_back(fragment.token); } @@ -1644,13 +1708,13 @@ llama_token llama_token_eom_impl(const struct llama_vocab & vocab) { } int32_t llama_tokenize_impl( - const struct llama_vocab & vocab, - const char * text, - int32_t text_len, - llama_token * tokens, - int32_t n_tokens_max, - bool add_special, - bool parse_special) { + const struct llama_vocab & vocab, + const char * text, + int32_t text_len, + llama_token * tokens, + int32_t n_tokens_max, + bool add_special, + bool parse_special) { auto res = llama_tokenize_internal(vocab, std::string(text, text_len), add_special, parse_special); if (n_tokens_max < (int) res.size()) { // LLAMA_LOG_ERROR("%s: too many tokens\n", __func__); @@ -1775,6 +1839,8 @@ int32_t llama_detokenize_impl( int32_t text_len_max, bool remove_special, bool unparse_special) { + GGML_ASSERT(vocab.tokenizer && "Tokenizer not initialized. Call llama_vocab::init_tokenizer() first."); + int32_t avail = text_len_max; int32_t total = 0; diff --git a/src/llama-vocab.h b/src/llama-vocab.h index cc46f642b..069bdc423 100644 --- a/src/llama-vocab.h +++ b/src/llama-vocab.h @@ -8,6 +8,8 @@ #include #include +struct llm_tokenizer; + struct llama_vocab { using id = llama_token; using token = std::string; @@ -65,7 +67,14 @@ struct llama_vocab { std::vector precompiled_charsmap; + llm_tokenizer * tokenizer = nullptr; + + llama_vocab() = default; + ~llama_vocab(); + int find_bpe_rank(const std::string & token_left, const std::string & token_right) const; + + void init_tokenizer(); }; // diff --git a/src/llama.cpp b/src/llama.cpp index f450eaf9d..44afb31d7 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -6464,6 +6464,8 @@ static void llm_load_vocab( } GGML_ASSERT(vocab.id_to_token.size() == vocab.token_to_id.size()); + vocab.init_tokenizer(); + // determine the newline token: LLaMA "<0x0A>" == 10 == '\n', Falcon 193 == '\n' if (vocab.type == LLAMA_VOCAB_TYPE_SPM) { // For Fill-In-the-Middle (FIM)/infill models which where converted diff --git a/tests/test-tokenizer-0.cpp b/tests/test-tokenizer-0.cpp index d3d21331b..4d49850c9 100644 --- a/tests/test-tokenizer-0.cpp +++ b/tests/test-tokenizer-0.cpp @@ -7,6 +7,7 @@ #include #include #include +#include //static const std::map> & k_tests() { // static std::map> _k_tests = { @@ -194,45 +195,64 @@ int main(int argc, char **argv) { const bool add_special = false; - for (const auto & test_kv : k_tests) { - const std::vector res = llama_tokenize(ctx, test_kv.first, add_special, false); + // multi-threaded tokenization + const int nthread = std::thread::hardware_concurrency(); + std::vector threads(nthread); - printf("\n"); - printf("src: '%s'\n", test_kv.first.c_str()); - printf("res: '%s'\n", llama_detokenize(ctx, res).c_str()); - printf("tok: "); - for (const auto & tok : res) { - printf("%d ", tok); - } - printf("\n"); + for (int i = 0; i < nthread; i++) { + threads[i] = std::thread([&, i]() { + for (const auto & test_kv : k_tests) { + const std::vector res = llama_tokenize(ctx, test_kv.first, add_special, false); - bool correct = res.size() == test_kv.second.size(); - for (int i = 0; i < (int) res.size() && correct; ++i) { - if (test_kv.second[i] != res[i]) { - correct = false; + // here only print the result of the first thread + // because the other threads are running the same tests + if (i != 0) { + continue; + } + + printf("\n"); + printf("src: '%s'\n", test_kv.first.c_str()); + printf("res: '%s'\n", llama_detokenize(ctx, res).c_str()); + printf("tok: "); + for (const auto & tok : res) { + printf("%d ", tok); + } + printf("\n"); + + bool correct = res.size() == test_kv.second.size(); + for (int i = 0; i < (int) res.size() && correct; ++i) { + if (test_kv.second[i] != res[i]) { + correct = false; + } + } + + if (!correct) { + fprintf(stderr, "%s : failed test: '%s'\n", __func__, test_kv.first.c_str()); + fprintf(stderr, "%s : detokenized to: '%s' instead of '%s'\n", __func__, + llama_detokenize(ctx, res).c_str(), + llama_detokenize(ctx, test_kv.second).c_str()); + fprintf(stderr, "%s : expected tokens: ", __func__); + for (const auto & t : test_kv.second) { + fprintf(stderr, "%6d '%s', ", t, llama_token_to_piece(ctx, t).c_str()); + } + fprintf(stderr, "\n"); + fprintf(stderr, "%s : got tokens: ", __func__); + for (const auto & t : res) { + fprintf(stderr, "%6d '%s', ", t, llama_token_to_piece(ctx, t).c_str()); + } + fprintf(stderr, "\n"); + + success = false; + } } - } - - if (!correct) { - fprintf(stderr, "%s : failed test: '%s'\n", __func__, test_kv.first.c_str()); - fprintf(stderr, "%s : detokenized to: '%s' instead of '%s'\n", __func__, - llama_detokenize(ctx, res).c_str(), - llama_detokenize(ctx, test_kv.second).c_str()); - fprintf(stderr, "%s : expected tokens: ", __func__); - for (const auto & t : test_kv.second) { - fprintf(stderr, "%6d '%s', ", t, llama_token_to_piece(ctx, t).c_str()); - } - fprintf(stderr, "\n"); - fprintf(stderr, "%s : got tokens: ", __func__); - for (const auto & t : res) { - fprintf(stderr, "%6d '%s', ", t, llama_token_to_piece(ctx, t).c_str()); - } - fprintf(stderr, "\n"); - - success = false; - } + }); } + for (int i = 0; i < nthread; i++) { + threads[i].join(); + } + + // single threaded tokenization if (!fname_text.empty()) { fprintf(stderr, "%s : tokenizing: '%s'\n", __func__, fname_text.c_str()); From 739842703e32cd43443c45e0b4f6647cc4e6b3d6 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 28 Sep 2024 15:13:21 +0300 Subject: [PATCH 08/33] llama : add comment about thread-safety [no ci] (#9449) --- include/llama.h | 2 ++ 1 file changed, 2 insertions(+) diff --git a/include/llama.h b/include/llama.h index caef0bfff..4ea8a2c2b 100644 --- a/include/llama.h +++ b/include/llama.h @@ -911,6 +911,8 @@ extern "C" { // // Tokenization // + // The API is thread-safe. + // /// @details Convert the provided text into tokens. /// @param tokens The tokens pointer must be large enough to hold the resulting tokens. From 1b2f992cd2cff0b69e5abe78bb8888d51ed19d67 Mon Sep 17 00:00:00 2001 From: slaren Date: Sat, 28 Sep 2024 14:32:46 +0200 Subject: [PATCH 09/33] test-backend-ops : use flops for some performance tests (#9657) * test-backend-ops : use flops for some performance tests - parallelize tensor quantization - use a different set of cases for performance and correctness tests - run each test for at least one second --- tests/test-backend-ops.cpp | 411 +++++++++++++++++++------------------ 1 file changed, 214 insertions(+), 197 deletions(-) diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 9a96cfc4c..d2cfe06b5 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -32,63 +32,52 @@ #include #include #include +#include #include static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f) { - // static RNG initialization (revisit if n_threads stops being constant) - static const size_t n_threads = std::thread::hardware_concurrency(); - static std::vector generators = []() { - std::random_device rd; - std::vector vec; - vec.reserve(n_threads); - //for (size_t i = 0; i < n_threads; i++) { vec.emplace_back(1234 + i); } // fixed seed - for (size_t i = 0; i < n_threads; i++) { vec.emplace_back(rd()); } - return vec; - }(); + size_t nels = ggml_nelements(tensor); + std::vector data(nels); + { + // parallel initialization + static const size_t n_threads = std::thread::hardware_concurrency(); + // static RNG initialization (revisit if n_threads stops being constant) + static std::vector generators = []() { + std::random_device rd; + std::vector vec; + vec.reserve(n_threads); + //for (size_t i = 0; i < n_threads; i++) { vec.emplace_back(1234 + i); } // fixed seed + for (size_t i = 0; i < n_threads; i++) { vec.emplace_back(rd()); } + return vec; + }(); - size_t size = ggml_nelements(tensor); - std::vector data(size); + auto init_thread = [&](size_t ith, size_t start, size_t end) { + std::uniform_real_distribution distribution(min, max); + auto & gen = generators[ith]; + for (size_t i = start; i < end; i++) { + data[i] = distribution(gen); + } + }; - auto init_thread = [&](size_t ith, size_t start, size_t end) { - std::uniform_real_distribution distribution(min, max); - for (size_t i = start; i < end; i++) { - data[i] = distribution(generators[ith]); + std::vector> tasks; + tasks.reserve(n_threads); + for (size_t i = 0; i < n_threads; i++) { + size_t start = i*nels/n_threads; + size_t end = (i+1)*nels/n_threads; + tasks.push_back(std::async(std::launch::async, init_thread, i, start, end)); } - }; - - std::vector threads; - threads.reserve(n_threads); - for (size_t i = 0; i < n_threads; i++) { - size_t start = i*size/n_threads; - size_t end = (i+1)*size/n_threads; - threads.emplace_back(init_thread, i, start, end); - } - for (auto & t : threads) { - t.join(); - } - -#if 0 - const char * val_str = getenv("GGML_TEST_EPS"); - float val = 1e-9f; - if (val_str != nullptr) { - val = std::stof(val_str); - printf("GGML_TEST_EPS=%e\n", val); - } - - // test quantization with very small values that may result in nan scales due to division by zero - if (ggml_is_quantized(tensor->type)) { - for (int i = 0; i < 256; i++) { - data[i] = val; + for (auto & t : tasks) { + t.get(); } } -#endif if (tensor->type == GGML_TYPE_F32 || tensor->type == GGML_TYPE_I32) { - ggml_backend_tensor_set(tensor, data.data(), 0, size * sizeof(float)); + ggml_backend_tensor_set(tensor, data.data(), 0, nels * sizeof(float)); } else if (ggml_is_quantized(tensor->type) || tensor->type == GGML_TYPE_F16 || tensor->type == GGML_TYPE_BF16) { - GGML_ASSERT(size % ggml_blck_size(tensor->type) == 0); - std::vector dataq(ggml_row_size(tensor->type, size)); - std::vector imatrix(tensor->ne[0], 1.0f); // dummy importance matrix + GGML_ASSERT(nels % ggml_blck_size(tensor->type) == 0); + + // dummy importance matrix + std::vector imatrix(tensor->ne[0], 1.0f); const float * im = imatrix.data(); if (!ggml_quantize_requires_imatrix(tensor->type)) { // when the imatrix is optional, we want to test both quantization with and without imatrix @@ -98,15 +87,31 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m } } - ggml_quantize_chunk(tensor->type, data.data(), dataq.data(), 0, size/tensor->ne[0], tensor->ne[0], im); - GGML_ASSERT(ggml_validate_row_data(tensor->type, dataq.data(), dataq.size())); - // TODO: other cases - //#pragma omp parallel for - //for (int i = 0; i < tensor->ne[1]; i++) { - // ggml_quantize_chunk(tensor->type, data.data(), dataq.data(), - // i * tensor->ne[0], 1, tensor->ne[0], im); - //} + std::vector dataq(ggml_row_size(tensor->type, nels)); + { + // parallel quantization by block + size_t blck_size = ggml_blck_size(tensor->type); + size_t n_blocks = nels / blck_size; + auto quantize_thread = [&](size_t start, size_t end) { + ggml_quantize_chunk(tensor->type, data.data(), dataq.data(), + start * blck_size, end - start, blck_size, im); + }; + + const size_t min_blocks_per_thread = 1; + const size_t n_threads = std::min(std::thread::hardware_concurrency()/2, + std::max(1, n_blocks / min_blocks_per_thread)); + std::vector> tasks; + tasks.reserve(n_threads); + for (size_t i = 0; i < n_threads; i++) { + size_t start = i*n_blocks/n_threads; + size_t end = (i+1)*n_blocks/n_threads; + tasks.push_back(std::async(std::launch::async, quantize_thread, start, end)); + } + for (auto & t : tasks) { + t.get(); + } + } ggml_backend_tensor_set(tensor, dataq.data(), 0, dataq.size()); } else if (tensor->type == GGML_TYPE_I8 || tensor->type == GGML_TYPE_I16 || tensor->type == GGML_TYPE_I32) { // This is going to create some weird integers though. @@ -160,60 +165,6 @@ static std::vector tensor_to_float(const ggml_tensor * t) { return tv; } -/* -static double cosine_similarity(const float * v1, const float * v2, size_t n) { - double dot = 0.0; - double mag1 = 0.0; - double mag2 = 0.0; - - for (size_t i = 0; i < n; i++) { - if (std::isnan(v1[i]) || std::isnan(v2[i])) { - return -1.0f; - } - if (std::isinf(v1[i]) && std::isinf(v2[i])) { - continue; - } - dot += v1[i]*v2[i]; - mag1 += v1[i]*v1[i]; - mag2 += v2[i]*v2[i]; - } - - return dot/sqrt(mag1*mag2); -} - -static float distance(const float * v1, const float * v2, size_t n) { - double d = 0.0; - - for (size_t i = 0; i < n; i++) { - if (std::isnan(v1[i]) || std::isnan(v2[i])) { - return INFINITY; - } - if (std::isinf(v1[i]) && std::isinf(v2[i])) { - continue; - } - d += (v1[i] - v2[i])*(v1[i] - v2[i]); - } - - return sqrt(d); -} - -static float vec_len(const float * v, size_t n) { - double d = 0.0; - - for (size_t i = 0; i < n; i++) { - if (std::isnan(v[i])) { - return INFINITY; - } - if (std::isinf(v[i])) { - continue; - } - d += v[i]*v[i]; - } - - return sqrt(d); -} -*/ - // normalized mean squared error = mse(a, b) / mse(a, 0) static double nmse(const float * a, const float * b, size_t n) { double mse_a_b = 0.0; @@ -264,7 +215,6 @@ static double mean_abs_asymm(const float * a, const float * b, const size_t n, c } // utils for printing the variables of the test cases -#define VAR_TO_STR(x) (#x "=" + var_to_str(x)) template static std::string var_to_str(const T & x) { @@ -297,10 +247,6 @@ static std::string var_to_str(const std::array & x) { return s; } -//static std::string var_to_str(ggml_unary_op unary_op) { -// return ggml_unary_op_name(unary_op); -//} - static std::string var_to_str(ggml_type type) { return ggml_type_name(type); } @@ -313,6 +259,8 @@ static std::string var_to_str(ggml_op_pool pool) { } } +#define VAR_TO_STR(x) (#x "=" + var_to_str(x)) + #define VARS_TO_STR1(a) VAR_TO_STR(a) #define VARS_TO_STR2(a, b) VAR_TO_STR(a) + "," + VAR_TO_STR(b) #define VARS_TO_STR3(a, b, c) VAR_TO_STR(a) + "," + VARS_TO_STR2(b, c) @@ -370,13 +318,13 @@ struct test_case { return 1e-4; } - virtual float grad_eps(){ + virtual float grad_eps() { return 1e-1f; } // If false, estimate gradient with 2 points, neglects 3rd order derivative and higher. // If true, estimate gradient with 4 points, neglects 5th order derivative and higher. - virtual bool grad_precise(){ + virtual bool grad_precise() { return false; } @@ -409,6 +357,11 @@ struct test_case { return size; } + virtual uint64_t op_flops(ggml_tensor * t) { + GGML_UNUSED(t); + return 0; + } + ggml_cgraph * gf = nullptr; ggml_cgraph * gb = nullptr; @@ -651,12 +604,11 @@ struct test_case { } // align while also leaving some margin for variations in parameters - int align = 20; + int align = 8; int last = (len + align - 1) / align * align; if (last - len < 5) { last += align; } - last = std::max(last, 60); printf("%*s", last - len, ""); // allocate @@ -677,9 +629,25 @@ struct test_case { // warmup run ggml_backend_graph_compute(backend, gf); + // determine number of runs + int n_runs; + if (op_flops(out) > 0) { + // based on flops + const uint64_t GFLOP = 1000 * 1000 * 1000; + const uint64_t target_flops_cpu = 8ULL * GFLOP; + const uint64_t target_flops_gpu = 100ULL * GFLOP; + uint64_t target_flops = ggml_backend_is_cpu(backend) ? target_flops_cpu : target_flops_gpu; + n_runs = std::min(ggml_graph_size(gf) - ggml_graph_n_nodes(gf), target_flops / op_flops(out)) + 1; + } else { + // based on memory size + const size_t GB = 1ULL << 30; + const size_t target_size_cpu = 8 * GB; + const size_t target_size_gpu = 32 * GB; + size_t target_size = ggml_backend_is_cpu(backend) ? target_size_cpu : target_size_gpu; + n_runs = std::min(ggml_graph_size(gf) - ggml_graph_n_nodes(gf), target_size / op_size(out)) + 1; + } + // duplicate the op - size_t target_size = ggml_backend_is_cpu(backend) ? 1ULL << 33 : 1ULL << 35; // 8 GB CPU, 32 GB GPU - int n_runs = std::min((size_t) ggml_graph_size(gf) - ggml_graph_n_nodes(gf), target_size / op_size(out)) + 1; for (int i = 1; i < n_runs; i++) { ggml_graph_add_node(gf, out); } @@ -706,17 +674,47 @@ struct test_case { // run ggml_backend_synchronize(backend); - int64_t start_time = ggml_time_us(); - ggml_backend_graph_compute(backend, gf); - ggml_backend_synchronize(backend); - int64_t end_time = ggml_time_us(); - double time_us = end_time - start_time; + int64_t total_time_us = 0; + int total_runs = 0; + do { + int64_t start_time = ggml_time_us(); + ggml_backend_graph_compute(backend, gf); + ggml_backend_synchronize(backend); + int64_t end_time = ggml_time_us(); - printf(" %5d runs - %8.2f us/run - %8zu kB/run - \033[1;34m%7.2f GB/s\033[0m\n", - n_runs, - time_us / n_runs, - op_size(out) / 1024, - mem / (time_us/1e6) / 1024.0 / 1024.0 / 1024.0); + total_time_us += end_time - start_time; + total_runs += n_runs; + } while (total_time_us < 1000*1000); // run for at least 1 second + + printf(" %8d runs - %8.2f us/run - ", + total_runs, + (double)total_time_us / total_runs); + + if (op_flops(out) > 0) { + double flops_per_sec = (op_flops(out) * total_runs) / (total_time_us / 1e6); + auto format_flops = [](double flops) -> std::string { + char buf[256]; + if (flops >= 1e12) { + snprintf(buf, sizeof(buf), "%6.2f TFLOP", flops / 1e12); + } else if (flops >= 1e9) { + snprintf(buf, sizeof(buf), "%6.2f GFLOP", flops / 1e9); + } else if (flops >= 1e6) { + snprintf(buf, sizeof(buf), "%6.2f MFLOP", flops / 1e6); + } else { + snprintf(buf, sizeof(buf), "%6.2f KFLOP", flops / 1e3); + } + return buf; + }; + printf("%s/run - \033[1;34m%sS\033[0m", + format_flops(op_flops(out)).c_str(), + format_flops(flops_per_sec).c_str()); + + } else { + printf("%8zu kB/run - \033[1;34m%7.2f GB/s\033[0m", + op_size(out) / 1024, + mem / (total_time_us / 1e6) / 1024.0 / 1024.0 / 1024.0); + } + printf("\n"); ggml_backend_buffer_free(buf); @@ -1591,13 +1589,9 @@ struct test_mul_mat : public test_case { return 5e-4; } - size_t op_size(ggml_tensor * t) override { - size_t a = ggml_nbytes(t->src[0]) * n * nr[0] * nr[1]; - size_t b = ggml_nbytes(t->src[1]) * m; - size_t c = ggml_nbytes(t); - return a + b + c; - + uint64_t op_flops(ggml_tensor * t) override { GGML_UNUSED(t); + return 2 * m * n * k * bs[0] * nr[0] * bs[1] * nr[1]; } test_mul_mat(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32, @@ -1641,13 +1635,9 @@ struct test_mul_mat_id : public test_case { return 5e-4; } - size_t op_size(ggml_tensor * t) override { - size_t a = ggml_nbytes(t->src[2]) * n; - size_t b = ggml_nbytes(t->src[1]) * m; - size_t c = ggml_nbytes(t); - return a + b + c; - + uint64_t op_flops(ggml_tensor * t) override { GGML_UNUSED(t); + return 2 * m * k * n * n_used; } test_mul_mat_id(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32, @@ -3163,47 +3153,46 @@ struct test_falcon : public test_llm { // ########################################### // ## Section 3: GGML Op Test Instantiation ## // ########################################### +static const ggml_type all_types[] = { + GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16, + GGML_TYPE_Q4_0, GGML_TYPE_Q4_1, + GGML_TYPE_Q5_0, GGML_TYPE_Q5_1, + GGML_TYPE_Q8_0, + GGML_TYPE_Q2_K, GGML_TYPE_Q3_K, + GGML_TYPE_Q4_K, GGML_TYPE_Q5_K, + GGML_TYPE_Q6_K, + // GGML_TYPE_TQ1_0, GGML_TYPE_TQ2_0, // TODO: implement for all backends + GGML_TYPE_IQ2_XXS, GGML_TYPE_IQ2_XS, GGML_TYPE_IQ2_S, + GGML_TYPE_IQ3_XXS, GGML_TYPE_IQ1_S, GGML_TYPE_IQ1_M, + GGML_TYPE_IQ4_NL, GGML_TYPE_IQ3_S, GGML_TYPE_IQ4_XS, +}; +static const ggml_type base_types[] = { + GGML_TYPE_F32, GGML_TYPE_F16, + GGML_TYPE_Q4_0, + GGML_TYPE_Q4_K, + GGML_TYPE_IQ2_XXS +}; -static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_name) { +static const ggml_type other_types[] = { + GGML_TYPE_Q4_1, + GGML_TYPE_Q5_0, GGML_TYPE_Q5_1, + GGML_TYPE_Q8_0, + GGML_TYPE_Q2_K, GGML_TYPE_Q3_K, + GGML_TYPE_Q5_K, + GGML_TYPE_Q6_K, + // GGML_TYPE_TQ1_0, GGML_TYPE_TQ2_0, // TODO: implement for all backends + GGML_TYPE_IQ2_XS, GGML_TYPE_IQ2_S, + GGML_TYPE_IQ3_XXS, GGML_TYPE_IQ1_S, GGML_TYPE_IQ1_M, + GGML_TYPE_IQ4_NL, GGML_TYPE_IQ3_S, GGML_TYPE_IQ4_XS, + GGML_TYPE_BF16, +}; + +// Test cases for evaluation: should try to cover edge cases while using small input sizes to keep the runtime low +static std::vector> make_test_cases_eval() { std::vector> test_cases; std::default_random_engine rng(0); - const ggml_type all_types[] = { - GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16, - GGML_TYPE_Q4_0, GGML_TYPE_Q4_1, - GGML_TYPE_Q5_0, GGML_TYPE_Q5_1, - GGML_TYPE_Q8_0, - GGML_TYPE_Q2_K, GGML_TYPE_Q3_K, - GGML_TYPE_Q4_K, GGML_TYPE_Q5_K, - GGML_TYPE_Q6_K, - // GGML_TYPE_TQ1_0, GGML_TYPE_TQ2_0, // TODO: implement for all backends - GGML_TYPE_IQ2_XXS, GGML_TYPE_IQ2_XS, GGML_TYPE_IQ2_S, - GGML_TYPE_IQ3_XXS, GGML_TYPE_IQ1_S, GGML_TYPE_IQ1_M, - GGML_TYPE_IQ4_NL, GGML_TYPE_IQ3_S, GGML_TYPE_IQ4_XS, - }; - - const ggml_type base_types[] = { - GGML_TYPE_F32, GGML_TYPE_F16, - GGML_TYPE_Q4_0, - GGML_TYPE_Q4_K, - GGML_TYPE_IQ2_XXS - }; - - const ggml_type other_types[] = { - GGML_TYPE_Q4_1, - GGML_TYPE_Q5_0, GGML_TYPE_Q5_1, - GGML_TYPE_Q8_0, - GGML_TYPE_Q2_K, GGML_TYPE_Q3_K, - GGML_TYPE_Q5_K, - GGML_TYPE_Q6_K, - // GGML_TYPE_TQ1_0, GGML_TYPE_TQ2_0, // TODO: implement for all backends - GGML_TYPE_IQ2_XS, GGML_TYPE_IQ2_S, - GGML_TYPE_IQ3_XXS, GGML_TYPE_IQ1_S, GGML_TYPE_IQ1_M, - GGML_TYPE_IQ4_NL, GGML_TYPE_IQ3_S, GGML_TYPE_IQ4_XS, - GGML_TYPE_BF16, - }; - // unary ops for (int v : {0, 1}) { for (int op = 0; op < GGML_UNARY_OP_COUNT; op++) { @@ -3392,6 +3381,14 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {2, 2})); } } + for (ggml_type type_a : other_types) { + for (ggml_type type_b : {GGML_TYPE_F32}) { + if (ggml_blck_size(type_a) != 256) { + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, ggml_blck_size(type_a), {1, 1}, {1, 1})); + } + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {1, 1}, {1, 1})); + } + } #else // m = a rows // n = b rows @@ -3411,15 +3408,6 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op } #endif - for (ggml_type type_a : other_types) { - for (ggml_type type_b : {GGML_TYPE_F32}) { - if (ggml_blck_size(type_a) != 256) { - test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, ggml_blck_size(type_a), {1, 1}, {1, 1})); - } - test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {1, 1}, {1, 1})); - } - } - test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 64, 2, 128, { 8, 1}, {1, 1})); test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 83, 2, 128, { 8, 1}, {4, 1})); test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 64, 2, 64, { 8, 1}, {4, 1})); @@ -3624,20 +3612,30 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_falcon(2)); #endif - // run tests - if (mode == MODE_GRAD) { - size_t n_ok = 0; - for (auto & test : test_cases) { - if (test->eval_grad(backend, op_name)) { - n_ok++; + return test_cases; +} + +// Test cases for performance evaluation: should be representative of real-world use cases +static std::vector> make_test_cases_perf() { + std::vector> test_cases; + + test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 1, 1, 1})); + test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 512, 1, 1})); + + for (int bs : {1, 512}) { + for (ggml_type type_a : all_types) { + for (ggml_type type_b : {GGML_TYPE_F32}) { + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, bs, 14336, {1, 1}, {1, 1})); } } - printf(" %zu/%zu tests passed\n", n_ok, test_cases.size()); - - return n_ok == test_cases.size(); } + return test_cases; +} + +static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_name) { if (mode == MODE_TEST) { + auto test_cases = make_test_cases_eval(); ggml_backend_t backend_cpu = ggml_backend_cpu_init(); size_t n_ok = 0; @@ -3653,7 +3651,21 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op return n_ok == test_cases.size(); } + if (mode == MODE_GRAD) { + auto test_cases = make_test_cases_eval(); + size_t n_ok = 0; + for (auto & test : test_cases) { + if (test->eval_grad(backend, op_name)) { + n_ok++; + } + } + printf(" %zu/%zu tests passed\n", n_ok, test_cases.size()); + + return n_ok == test_cases.size(); + } + if (mode == MODE_PERF) { + auto test_cases = make_test_cases_perf(); for (auto & test : test_cases) { test->eval_perf(backend, op_name); } @@ -3667,9 +3679,9 @@ static void usage(char ** argv) { printf("Usage: %s [mode] [-o op] [-b backend]\n", argv[0]); printf(" valid modes:\n"); printf(" - test (default, compare with CPU backend for correctness)\n"); - printf(" - perf (performance evaluation)\n"); printf(" - grad (compare gradients from backpropagation with method of finite differences)\n"); - printf(" op names are as given by ggml_op_desc() (e.g. GGML_ADD)\n"); + printf(" - perf (performance evaluation)\n"); + printf(" op names for -o are as given by ggml_op_desc() (e.g. ADD, MUL_MAT, etc)\n"); } int main(int argc, char ** argv) { @@ -3728,6 +3740,11 @@ int main(int argc, char ** argv) { continue; } + if (ggml_backend_is_cpu(backend)) { + // TODO: better value for n_threads + ggml_backend_cpu_set_n_threads(backend, std::thread::hardware_concurrency() / 2); + } + printf(" Backend name: %s\n", ggml_backend_name(backend)); bool ok = test_backend(backend, mode, op_name_filter); From f4d2b8846a6b34419ff9e9491aee6cd95e444bfc Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 28 Sep 2024 17:42:03 +0300 Subject: [PATCH 10/33] llama : add reranking support (#9510) * py : add XLMRobertaForSequenceClassification [no ci] * py : fix scalar-tensor conversion [no ci] * py : fix position embeddings chop [no ci] * llama : read new cls tensors [no ci] * llama : add classigication head (wip) [no ci] * llama : add "rank" pooling type ggml-ci * server : add rerank endpoint ggml-ci * llama : aboud ggml_repeat during classification * rerank : cleanup + comments * server : accept /rerank endpoint in addition to /v1/rerank [no ci] * embedding : parse special tokens * jina : support v1 reranker * vocab : minor style ggml-ci * server : initiate tests for later ggml-ci * server : add docs * llama : add comment [no ci] * llama : fix uninitialized tensors * ci : add rerank tests ggml-ci * add reranking test * change test data * Update examples/server/server.cpp Co-authored-by: Xuan Son Nguyen * add `--reranking` argument * update server docs * llama : fix comment [no ci] ggml-ci --------- Co-authored-by: Xuan Son Nguyen Co-authored-by: Xuan Son Nguyen --- ci/run.sh | 85 ++++++- common/arg.cpp | 18 +- common/common.cpp | 5 + common/common.h | 1 + convert_hf_to_gguf.py | 27 ++- convert_hf_to_gguf_update.py | 1 + examples/embedding/embedding.cpp | 7 +- examples/server/README.md | 39 +++- examples/server/server.cpp | 215 ++++++++++++++++-- .../server/tests/features/embeddings.feature | 2 +- examples/server/tests/features/rerank.feature | 42 ++++ examples/server/tests/features/steps/steps.py | 54 ++++- examples/server/utils.hpp | 25 +- gguf-py/gguf/constants.py | 7 + gguf-py/gguf/tensor_mapping.py | 9 + include/llama.h | 10 +- src/llama-vocab.cpp | 15 +- src/llama.cpp | 96 +++++++- 18 files changed, 602 insertions(+), 56 deletions(-) create mode 100644 examples/server/tests/features/rerank.feature diff --git a/ci/run.sh b/ci/run.sh index 1ac08ee4e..7d241ecc0 100755 --- a/ci/run.sh +++ b/ci/run.sh @@ -712,6 +712,81 @@ function gg_run_embd_bge_small { set +e } +function gg_sum_embd_bge_small { + gg_printf '### %s\n\n' "${ci}" + + gg_printf 'BGE Small (BERT):\n' + gg_printf '- status: %s\n' "$(cat $OUT/${ci}.exit)" + gg_printf '- f16: \n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-f16.log)" + gg_printf '- q8_0:\n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-q8_0.log)" +} + +# rerank_tiny + +function gg_run_rerank_tiny { + cd ${SRC} + + gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/config.json + gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/tokenizer.json + gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/tokenizer_config.json + gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/special_tokens_map.json + gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/resolve/main/pytorch_model.bin + gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/sentence_bert_config.json + gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/vocab.txt + gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/modules.json + gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/config.json + + gg_wget models-mnt/rerank-tiny/1_Pooling https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/1_Pooling/config.json + + path_models="../models-mnt/rerank-tiny" + + rm -rf build-ci-release && mkdir build-ci-release && cd build-ci-release + + set -e + + (time cmake -DCMAKE_BUILD_TYPE=Release ${CMAKE_EXTRA} .. ) 2>&1 | tee -a $OUT/${ci}-cmake.log + (time make -j$(nproc) ) 2>&1 | tee -a $OUT/${ci}-make.log + + python3 ../convert_hf_to_gguf.py ${path_models} --outfile ${path_models}/ggml-model-f16.gguf + + model_f16="${path_models}/ggml-model-f16.gguf" + + (time ./bin/llama-embedding --model ${model_f16} -p "what is panda?hi\nwhat is panda?it's a bear\nwhat is panda?The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China." --pooling rank --embd-normalize -1 --verbose-prompt) 2>&1 | tee -a $OUT/${ci}-rk-f16.log + + # sample output + # rerank score 0: 0.029 + # rerank score 1: 0.029 + # rerank score 2: 0.135 + + # check that the score is in the range [$3, $4] + function check_score { + qnt="$1" + score=$(echo "$2" | grep -oE "[0-9]+\.[0-9]+" | tail -n 1) + + if [ $(echo "$score < $3" | bc) -eq 1 ] || [ $(echo "$score > $4" | bc) -eq 1 ]; then + printf ' - %s @ %s (FAIL: score not in range [%s, %s])\n' "$qnt" "$score" "$3" "$4" + return 20 + fi + + printf ' - %s @ %s OK\n' "$qnt" "$score" + return 0 + } + + check_score "rerank score 0" "$(cat $OUT/${ci}-rk-f16.log | grep "rerank score 0")" "0.00" "0.05" | tee -a $OUT/${ci}-rk-f16.log + check_score "rerank score 1" "$(cat $OUT/${ci}-rk-f16.log | grep "rerank score 1")" "0.00" "0.05" | tee -a $OUT/${ci}-rk-f16.log + check_score "rerank score 2" "$(cat $OUT/${ci}-rk-f16.log | grep "rerank score 2")" "0.10" "0.15" | tee -a $OUT/${ci}-rk-f16.log + + set +e +} + +function gg_sum_rerank_tiny { + gg_printf '### %s\n\n' "${ci}" + + gg_printf 'Rerank Tiny (Jina):\n' + gg_printf '- status: %s\n' "$(cat $OUT/${ci}.exit)" + gg_printf '- f16: \n```\n%s\n```\n' "$(cat $OUT/${ci}-rk-f16.log)" +} + function gg_check_build_requirements { if ! command -v cmake &> /dev/null; then gg_printf 'cmake not found, please install' @@ -726,15 +801,6 @@ function gg_check_build_requirements { fi } -function gg_sum_embd_bge_small { - gg_printf '### %s\n\n' "${ci}" - - gg_printf 'BGE Small (BERT):\n' - gg_printf '- status: %s\n' "$(cat $OUT/${ci}.exit)" - gg_printf '- f16: \n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-f16.log)" - gg_printf '- q8_0:\n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-q8_0.log)" -} - ## main export LLAMA_LOG_PREFIX=1 @@ -762,6 +828,7 @@ test $ret -eq 0 && gg_run ctest_release if [ -z ${GG_BUILD_LOW_PERF} ]; then test $ret -eq 0 && gg_run embd_bge_small + test $ret -eq 0 && gg_run rerank_tiny if [ -z ${GG_BUILD_CLOUD} ] || [ ${GG_BUILD_EXTRA_TESTS_0} ]; then test $ret -eq 0 && gg_run test_scripts_debug diff --git a/common/arg.cpp b/common/arg.cpp index 6880117ed..8266a16c2 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -284,6 +284,10 @@ static bool gpt_params_parse_ex(int argc, char ** argv, gpt_params_context & ctx params.kv_overrides.back().key[0] = 0; } + if (params.reranking && params.embedding) { + throw std::invalid_argument("error: either --embedding or --reranking can be specified, but not both"); + } + return true; } @@ -391,7 +395,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex, [](gpt_params & params) { params.verbose_prompt = true; } - ).set_examples({LLAMA_EXAMPLE_MAIN})); + )); add_opt(llama_arg( {"--no-display-prompt"}, format("don't print prompt at generation (default: %s)", !params.display_prompt ? "true" : "false"), @@ -1093,13 +1097,14 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex, } ).set_sparam()); add_opt(llama_arg( - {"--pooling"}, "{none,mean,cls,last}", + {"--pooling"}, "{none,mean,cls,last,rank}", "pooling type for embeddings, use model default if unspecified", [](gpt_params & params, const std::string & value) { /**/ if (value == "none") { params.pooling_type = LLAMA_POOLING_TYPE_NONE; } else if (value == "mean") { params.pooling_type = LLAMA_POOLING_TYPE_MEAN; } - else if (value == "cls") { params.pooling_type = LLAMA_POOLING_TYPE_CLS; } + else if (value == "cls") { params.pooling_type = LLAMA_POOLING_TYPE_CLS; } else if (value == "last") { params.pooling_type = LLAMA_POOLING_TYPE_LAST; } + else if (value == "rank") { params.pooling_type = LLAMA_POOLING_TYPE_RANK; } else { throw std::invalid_argument("invalid value"); } } ).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_POOLING")); @@ -1749,6 +1754,13 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex, params.embedding = true; } ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_EMBEDDINGS")); + add_opt(llama_arg( + {"--reranking", "--rerank"}, + format("enable reranking endpoint on server (default: %s)", params.reranking ? "enabled" : "disabled"), + [](gpt_params & params) { + params.reranking = true; + } + ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_RERANKING")); add_opt(llama_arg( {"--api-key"}, "KEY", "API key to use for authentication (default: none)", diff --git a/common/common.cpp b/common/common.cpp index 8d0ed4f95..e2b8574bf 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1023,6 +1023,11 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param cparams.flash_attn = params.flash_attn; cparams.no_perf = params.no_perf; + if (params.reranking) { + cparams.embeddings = true; + cparams.pooling_type = LLAMA_POOLING_TYPE_RANK; + } + cparams.type_k = kv_cache_type_from_str(params.cache_type_k); cparams.type_v = kv_cache_type_from_str(params.cache_type_v); diff --git a/common/common.h b/common/common.h index cb87c4479..8b84cf9ad 100644 --- a/common/common.h +++ b/common/common.h @@ -271,6 +271,7 @@ struct gpt_params { int32_t embd_normalize = 2; // normalisation for embendings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm) std::string embd_out = ""; // empty = default, "array" = [[],[]...], "json" = openai style, "json+" = same "json" + cosine similarity matrix std::string embd_sep = "\n"; // separator of embendings + bool reranking = false; // enable reranking support on server // server params int32_t port = 8080; // server listens on this network port diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 2cd5a8c11..96a8830e9 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -291,8 +291,13 @@ class Model: bid = int(part) break - for new_name, data in ((n, d.squeeze().numpy()) for n, d in self.modify_tensors(data_torch, name, bid)): - data: np.ndarray # type hint + for new_name, data_torch in (self.modify_tensors(data_torch, name, bid)): + data = data_torch.squeeze().numpy() + + # if data ends up empty, it means data_torch was a scalar tensor -> restore + if len(data.shape) == 0: + data = data_torch.numpy() + n_dims = len(data.shape) data_qtype: gguf.GGMLQuantizationType | bool = self.tensor_force_quant(name, new_name, bid, n_dims) @@ -592,6 +597,9 @@ class Model: if chkhsh == "a8594e3edff7c29c003940395316294b2c623e09894deebbc65f33f1515df79e": # ref: https://huggingface.co/databricks/dbrx-base res = "dbrx" + if chkhsh == "c7699093ba4255a91e702aa38a596aa81669f3525dae06c2953267dde580f448": + # ref: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en + res = "jina-v1-en" if chkhsh == "0876d13b50744004aa9aeae05e7b0647eac9d801b5ba4668afc01e709c15e19f": # ref: https://huggingface.co/jinaai/jina-embeddings-v2-base-en res = "jina-v2-en" @@ -2601,7 +2609,7 @@ class NomicBertModel(BertModel): self.gguf_writer.add_rope_freq_base(self.hparams["rotary_emb_base"]) -@Model.register("XLMRobertaModel") +@Model.register("XLMRobertaModel", "XLMRobertaForSequenceClassification") class XLMRobertaModel(BertModel): model_arch = gguf.MODEL_ARCH.BERT @@ -2699,6 +2707,11 @@ class XLMRobertaModel(BertModel): self.gguf_writer.add_add_eos_token(True) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # if name starts with "roberta.", remove the prefix + # e.g. https://huggingface.co/BAAI/bge-reranker-v2-m3/tree/main + if name.startswith("roberta."): + name = name[8:] + # position embeddings start at pad_token_id + 1, so just chop down the weight tensor if name == "embeddings.position_embeddings.weight": if self._position_offset is not None: @@ -3110,6 +3123,14 @@ class JinaBertV2Model(BertModel): self.gguf_writer.add_add_bos_token(True) self.gguf_writer.add_add_eos_token(True) + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # if name starts with "bert.", remove the prefix + # e.g. https://huggingface.co/jinaai/jina-reranker-v1-tiny-en + if name.startswith("bert."): + name = name[5:] + + return super().modify_tensors(data_torch, name, bid) + @Model.register("OpenELMForCausalLM") class OpenELMModel(Model): diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py index 4d11059f3..022354a3b 100755 --- a/convert_hf_to_gguf_update.py +++ b/convert_hf_to_gguf_update.py @@ -81,6 +81,7 @@ models = [ {"name": "qwen2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen1.5-7B", }, {"name": "olmo", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/allenai/OLMo-1.7-7B-hf", }, {"name": "dbrx", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/databricks/dbrx-base", }, + {"name": "jina-v1-en", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-reranker-v1-tiny-en", }, {"name": "jina-v2-en", "tokt": TOKENIZER_TYPE.WPM, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-en", }, # WPM! {"name": "jina-v2-es", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-es", }, {"name": "jina-v2-de", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-de", }, diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index a438dcb5a..734926822 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -135,7 +135,7 @@ int main(int argc, char ** argv) { // tokenize the prompts and trim std::vector> inputs; for (const auto & prompt : prompts) { - auto inp = ::llama_tokenize(ctx, prompt, true, false); + auto inp = ::llama_tokenize(ctx, prompt, true, true); if (inp.size() > n_batch) { LOG_ERR("%s: number of tokens in input line (%lld) exceeds batch size (%lld), increase batch size and re-run\n", __func__, (long long int) inp.size(), (long long int) n_batch); @@ -234,6 +234,11 @@ int main(int argc, char ** argv) { } LOG("\n"); } + } else if (pooling_type == LLAMA_POOLING_TYPE_RANK) { + for (int j = 0; j < n_embd_count; j++) { + // NOTE: if you change this log - update the tests in ci/run.sh + LOG("rerank score %d: %8.3f\n", j, emb[j * n_embd]); + } } else { // print the first part of the embeddings or for a single prompt, the full embedding for (int j = 0; j < n_prompts; j++) { diff --git a/examples/server/README.md b/examples/server/README.md index dfca07f98..951c4a44c 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -7,6 +7,7 @@ Set of LLM REST APIs and a simple web front end to interact with llama.cpp. **Features:** * LLM inference of F16 and quantized models on GPU and CPU * [OpenAI API](https://github.com/openai/openai-openapi) compatible chat completions and embeddings routes + * Reranking endoint (WIP: https://github.com/ggerganov/llama.cpp/pull/9510) * Parallel decoding with multi-user support * Continuous batching * Multimodal (wip) @@ -23,6 +24,7 @@ The project is under active development, and we are [looking for feedback and co | -------- | ----------- | | `-h, --help, --usage` | print usage and exit | | `--version` | show version and build info | +| `--verbose-prompt` | print a verbose prompt before generation (default: false) | | `-t, --threads N` | number of threads to use during generation (default: -1)
(env: LLAMA_ARG_THREADS) | | `-tb, --threads-batch N` | number of threads to use during batch and prompt processing (default: same as --threads) | | `-C, --cpu-mask M` | CPU affinity mask: arbitrarily long hex. Complements cpu-range (default: "") | @@ -130,7 +132,7 @@ The project is under active development, and we are [looking for feedback and co | `--no-context-shift` | disables context shift on inifinite text generation (default: disabled)
(env: LLAMA_ARG_NO_CONTEXT_SHIFT) | | `-sp, --special` | special tokens output enabled (default: false) | | `--spm-infill` | use Suffix/Prefix/Middle pattern for infill (instead of Prefix/Suffix/Middle) as some models prefer this. (default: disabled) | -| `--pooling {none,mean,cls,last}` | pooling type for embeddings, use model default if unspecified
(env: LLAMA_ARG_POOLING) | +| `--pooling {none,mean,cls,last,rank}` | pooling type for embeddings, use model default if unspecified
(env: LLAMA_ARG_POOLING) | | `-cb, --cont-batching` | enable continuous batching (a.k.a dynamic batching) (default: enabled)
(env: LLAMA_ARG_CONT_BATCHING) | | `-nocb, --no-cont-batching` | disable continuous batching
(env: LLAMA_ARG_NO_CONT_BATCHING) | | `-a, --alias STRING` | set alias for model name (to be used by REST API)
(env: LLAMA_ARG_ALIAS) | @@ -138,6 +140,7 @@ The project is under active development, and we are [looking for feedback and co | `--port PORT` | port to listen (default: 8080)
(env: LLAMA_ARG_PORT) | | `--path PATH` | path to serve static files from (default: )
(env: LLAMA_ARG_STATIC_PATH) | | `--embedding, --embeddings` | restrict to only support embedding use case; use only with dedicated embedding models (default: disabled)
(env: LLAMA_ARG_EMBEDDINGS) | +| `--reranking, --rerank` | enable reranking endpoint on server (default: disabled)
(env: LLAMA_ARG_RERANKING) | | `--api-key KEY` | API key to use for authentication (default: none)
(env: LLAMA_API_KEY) | | `--api-key-file FNAME` | path to file containing API keys (default: none) | | `--ssl-key-file FNAME` | path to file a PEM-encoded SSL private key
(env: LLAMA_ARG_SSL_KEY_FILE) | @@ -152,6 +155,7 @@ The project is under active development, and we are [looking for feedback and co | `-sps, --slot-prompt-similarity SIMILARITY` | how much the prompt of a request must match the prompt of a slot in order to use that slot (default: 0.50, 0.0 = disabled)
| | `--lora-init-without-apply` | load LoRA adapters without applying them (apply later via POST /lora-adapters) (default: disabled) | + Note: If both command line argument and environment variable are both set for the same param, the argument will take precedence over env var. Example usage of docker compose with environment variables: @@ -478,6 +482,39 @@ The same as [the embedding example](../embedding) does. `image_data`: An array of objects to hold base64-encoded image `data` and its `id`s to be reference in `content`. You can determine the place of the image in the content as in the following: `Image: [img-21].\nCaption: This is a picture of a house`. In this case, `[img-21]` will be replaced by the embeddings of the image with id `21` in the following `image_data` array: `{..., "image_data": [{"data": "", "id": 21}]}`. Use `image_data` only with multimodal models, e.g., LLaVA. +### POST `/reranking`: Rerank documents according to a given query + +Similar to https://jina.ai/reranker/ but might change in the future. +Requires a reranker model (such as [bge-reranker-v2-m3](https://huggingface.co/BAAI/bge-reranker-v2-m3)) and the `--embedding --pooling rank` options. + + *Options:* + + `query`: The query against which the documents will be ranked. + + `documents`: An array strings representing the documents to be ranked. + + *Aliases:* + - `/rerank` + - `/v1/rerank` + - `/v1/reranking` + + *Examples:* + + ```shell + curl http://127.0.0.1:8012/v1/rerank \ + -H "Content-Type: application/json" \ + -d '{ + "model": "some-model", + "query": "What is panda?", + "top_n": 3, + "documents": [ + "hi", + "it is a bear", + "The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China." + ] + }' | jq + ``` + ### POST `/infill`: For code infilling. Takes a prefix and a suffix and returns the predicted completion as stream. diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 61ff09bb2..f343cc252 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -92,6 +92,7 @@ enum server_task_type { enum server_task_cmpl_type { SERVER_TASK_CMPL_TYPE_NORMAL, SERVER_TASK_CMPL_TYPE_EMBEDDING, + SERVER_TASK_CMPL_TYPE_RERANK, SERVER_TASK_CMPL_TYPE_INFILL, }; @@ -172,6 +173,7 @@ struct server_slot { std::vector generated_token_probs; server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL; + bool has_next_token = true; bool truncated = false; bool stopped_eos = false; @@ -954,8 +956,17 @@ struct server_context { slot.prompt = *prompt; } else if (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_array()) { slot.prompt = prompt->at(0); + } else if (prompt->is_array() && prompt->size() > 1) { + // array of strings + for (const auto & el : *prompt) { + if (!el.is_string()) { + send_error(task, "\"prompt\" must be a string, an array of strings or an array of integers", ERROR_TYPE_INVALID_REQUEST); + return false; + } + } + slot.prompt = *prompt; } else { - send_error(task, "\"prompt\" must be a string or an array of integers", ERROR_TYPE_INVALID_REQUEST); + send_error(task, "\"prompt\" must be a string, an array of strings or an array of integers", ERROR_TYPE_INVALID_REQUEST); return false; } } @@ -1389,6 +1400,7 @@ struct server_context { res.data = json { {"embedding", std::vector(n_embd, 0.0f)}, + {"index", slot.index}, }; continue; @@ -1407,6 +1419,44 @@ struct server_context { queue_results.send(res); } + void send_rerank(const server_slot & slot, const llama_batch & batch) { + server_task_result res; + res.id = slot.id_task; + res.error = false; + res.stop = true; + + for (int i = 0; i < batch.n_tokens; ++i) { + if (!batch.logits[i] || batch.seq_id[i][0] != slot.id + 1) { + continue; + } + + const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); + if (embd == NULL) { + embd = llama_get_embeddings_ith(ctx, i); + } + + if (embd == NULL) { + SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]); + + res.data = json { + {"index", slot.index}, + {"score", -1e6}, + }; + + continue; + } + + res.data = json { + {"index", slot.index}, + {"score", embd[0]}, + }; + } + + SLT_DBG(slot, "sending rerank result, res = '%s'\n", res.data.dump().c_str()); + + queue_results.send(res); + } + // // Functions to create new task(s) and receive result(s) // @@ -1442,13 +1492,27 @@ struct server_context { // otherwise, it's a multiple-prompt task, we break it into smaller tasks else if (prompt.is_array()) { std::vector prompts = prompt; - for (size_t i = 0; i < prompts.size(); i++) { - const auto & e = prompts[i]; - if (e.is_string() || json_is_array_of_numbers(e)) { - data["index"] = i; - create_task(data, true, e); - } else { - throw std::runtime_error(error_msg); + if (cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) { + // prompts[0] is the question + // the rest are the answers/documents + SRV_DBG("creating rerank tasks, n_prompts = %d\n", (int) prompts.size() - 1); + for (size_t i = 1; i < prompts.size(); i++) { + json qd; + qd.push_back(prompts[0]); + qd.push_back(prompts[i]); + data["index"] = i - 1; + create_task(data, true, qd); + } + } else { + SRV_DBG("creating multi-prompt tasks, n_prompts = %d\n", (int) prompts.size()); + for (size_t i = 0; i < prompts.size(); i++) { + const auto & e = prompts[i]; + if (e.is_string() || json_is_array_of_numbers(e)) { + data["index"] = i; + create_task(data, true, e); + } else { + throw std::runtime_error(error_msg); + } } } } @@ -1492,7 +1556,9 @@ struct server_context { return; } - size_t idx = result.data["index"]; + const size_t idx = result.data["index"]; + GGML_ASSERT(idx < results.size() && "index out of range"); + results[idx] = result; } result_handler(results); @@ -1903,6 +1969,7 @@ struct server_context { // track if this is an embedding or non-embedding batch // if we've added sampled tokens above, we are in non-embedding mode // -1: none, 0: non-embedding, 1: embedding + // TODO: make enum int32_t batch_type = batch.n_tokens > 0 ? 0 : -1; // next, batch any pending prompts without exceeding n_batch @@ -1951,6 +2018,29 @@ struct server_context { } prompt_tokens = embd_inp; + } else if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) { + // require slot.prompt to be array of 2 strings + if (!slot.prompt.is_array() || slot.prompt.size() != 2) { + SLT_ERR(slot, "%s", "invalid prompt for rerank task\n"); + slot.release(); + send_error(slot, "invalid prompt for rerank task", ERROR_TYPE_INVALID_REQUEST); + continue; + } + + // prompt: querydoc + prompt_tokens.clear(); + prompt_tokens.push_back(llama_token_bos(model)); + { + const auto part = tokenize(slot.prompt[0], false); + prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end()); + } + prompt_tokens.push_back(llama_token_eos(model)); + prompt_tokens.push_back(llama_token_bos(model)); + { + const auto part = tokenize(slot.prompt[1], false); + prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end()); + } + prompt_tokens.push_back(llama_token_eos(model)); } else { prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt } @@ -1970,7 +2060,7 @@ struct server_context { continue; } - if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING) { + if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) { // this prompt is too large to process - discard it if (slot.n_prompt_tokens > n_ubatch) { slot.release(); @@ -2048,7 +2138,8 @@ struct server_context { slot.n_prompt_tokens_processed = 0; } - if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING) { + // non-causal tasks require to fit the entire prompt in the physical batch + if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) { // cannot fit the prompt in the current batch - will try next iter if (batch.n_tokens + slot.n_prompt_tokens > n_batch) { continue; @@ -2056,7 +2147,10 @@ struct server_context { } // check that we are in the right batch_type, if not defer the slot - bool slot_type = slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING ? 1 : 0; + const bool slot_type = + slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || + slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK ? 1 : 0; + if (batch_type == -1) { batch_type = slot_type; } else if (batch_type != slot_type) { @@ -2229,6 +2323,13 @@ struct server_context { continue; // continue loop of slots } + if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) { + send_rerank(slot, batch_view); + slot.release(); + slot.i_batch = -1; + continue; // continue loop of slots + } + // prompt evaluated for next-token prediction slot.state = SLOT_STATE_GENERATING; } else if (slot.state != SLOT_STATE_GENERATING) { @@ -2787,8 +2888,8 @@ int main(int argc, char ** argv) { }; const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](server_task_cmpl_type cmpl_type, json & data, httplib::Response & res) { - if (ctx_server.params.embedding) { - res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED)); + if (ctx_server.params.embedding || ctx_server.params.reranking) { + res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings` or `--reranking`", ERROR_TYPE_NOT_SUPPORTED)); return; } @@ -2848,8 +2949,8 @@ int main(int argc, char ** argv) { // TODO: maybe merge this function with "handle_completions_generic" const auto handle_chat_completions = [&ctx_server, ¶ms, &res_error, &res_ok, verbose](const httplib::Request & req, httplib::Response & res) { - if (ctx_server.params.embedding) { - res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED)); + if (ctx_server.params.embedding || ctx_server.params.reranking) { + res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings` or `--reranking`", ERROR_TYPE_NOT_SUPPORTED)); return; } @@ -2973,6 +3074,11 @@ int main(int argc, char ** argv) { }; const auto handle_embeddings = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) { + // TODO: somehow clean up this checks in the future + if (!ctx_server.params.embedding || ctx_server.params.reranking) { + res_error(res, format_error_response("This server does not support embeddings. Start it with `--embeddings` and without `--reranking`", ERROR_TYPE_NOT_SUPPORTED)); + return; + } const json body = json::parse(req.body); bool is_openai = false; @@ -3023,6 +3129,79 @@ int main(int argc, char ** argv) { res_ok(res, root); }; + const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) { + if (!ctx_server.params.reranking) { + res_error(res, format_error_response("This server does not support reranking. Start it with `--reranking`", ERROR_TYPE_NOT_SUPPORTED)); + return; + } + const json body = json::parse(req.body); + + // TODO: implement + //int top_n = 1; + //if (body.count("top_n") != 1) { + // top_n = body.at("top_n"); + //} else { + // res_error(res, format_error_response("\"top_n\" must be provided", ERROR_TYPE_INVALID_REQUEST)); + // return; + //} + + json query; + if (body.count("query") == 1) { + query = body.at("query"); + if (!query.is_string()) { + res_error(res, format_error_response("\"query\" must be a string", ERROR_TYPE_INVALID_REQUEST)); + return; + } + } else { + res_error(res, format_error_response("\"query\" must be provided", ERROR_TYPE_INVALID_REQUEST)); + return; + } + + std::vector documents = json_value(body, "documents", std::vector()); + if (documents.empty()) { + res_error(res, format_error_response("\"documents\" must be a non-empty string array", ERROR_TYPE_INVALID_REQUEST)); + return; + } + + // construct prompt object: array of ["query", "doc0", "doc1", ...] + json prompt; + prompt.push_back(query); + for (const auto & doc : documents) { + prompt.push_back(doc); + } + + LOG_DBG("rerank prompt: %s\n", prompt.dump().c_str()); + + // create and queue the task + json responses = json::array(); + bool error = false; + { + std::vector tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_RERANK); + ctx_server.queue_results.add_waiting_tasks(tasks); + ctx_server.queue_tasks.post(tasks); + + // get the result + std::unordered_set task_ids = server_task::get_list_id(tasks); + + ctx_server.receive_cmpl_results(task_ids, [&](std::vector & results) { + for (const auto & res : results) { + responses.push_back(res.data); + } + }, [&](const json & error_data) { + res_error(res, error_data); + error = true; + }); + } + + if (error) { + return; + } + + // write JSON response + json root = format_response_rerank(body, responses); + res_ok(res, root); + }; + const auto handle_lora_adapters_list = [&](const httplib::Request &, httplib::Response & res) { json result = json::array(); for (size_t i = 0; i < ctx_server.loras.size(); ++i) { @@ -3119,6 +3298,10 @@ int main(int argc, char ** argv) { svr->Post("/embedding", handle_embeddings); // legacy svr->Post("/embeddings", handle_embeddings); svr->Post("/v1/embeddings", handle_embeddings); + svr->Post("/rerank", handle_rerank); + svr->Post("/reranking", handle_rerank); + svr->Post("/v1/rerank", handle_rerank); + svr->Post("/v1/reranking", handle_rerank); svr->Post("/tokenize", handle_tokenize); svr->Post("/detokenize", handle_detokenize); // LoRA adapters hotswap diff --git a/examples/server/tests/features/embeddings.feature b/examples/server/tests/features/embeddings.feature index 818ea3beb..f4fe2ee43 100644 --- a/examples/server/tests/features/embeddings.feature +++ b/examples/server/tests/features/embeddings.feature @@ -15,7 +15,7 @@ Feature: llama.cpp server And 128 as batch size And 128 as ubatch size And 512 KV cache size - And embeddings extraction + And enable embeddings endpoint Then the server is starting Then the server is healthy diff --git a/examples/server/tests/features/rerank.feature b/examples/server/tests/features/rerank.feature new file mode 100644 index 000000000..c36cc8e21 --- /dev/null +++ b/examples/server/tests/features/rerank.feature @@ -0,0 +1,42 @@ +@llama.cpp +@rerank +Feature: llama.cpp server + + Background: Server startup + Given a server listening on localhost:8080 + And a model url https://huggingface.co/ggml-org/models/resolve/main/jina-reranker-v1-tiny-en/ggml-model-f16.gguf + And a model file jina-reranker-v1-tiny-en.gguf + And a model alias jina-reranker-v1-tiny-en + And 42 as server seed + And 2 slots + And 512 as batch size + And 512 as ubatch size + And 512 KV cache size + And enable reranking endpoint + Then the server is starting + Then the server is healthy + + Scenario: Rerank + Given a rerank query: + """ + Machine learning is + """ + And a rerank document: + """ + A machine is a physical system that uses power to apply forces and control movement to perform an action. The term is commonly applied to artificial devices, such as those employing engines or motors, but also to natural biological macromolecules, such as molecular machines. + """ + And a rerank document: + """ + Learning is the process of acquiring new understanding, knowledge, behaviors, skills, values, attitudes, and preferences. The ability to learn is possessed by humans, non-human animals, and some machines; there is also evidence for some kind of learning in certain plants. + """ + And a rerank document: + """ + Machine learning is a field of study in artificial intelligence concerned with the development and study of statistical algorithms that can learn from data and generalize to unseen data, and thus perform tasks without explicit instructions. + """ + And a rerank document: + """ + Paris, capitale de la France, est une grande ville européenne et un centre mondial de l'art, de la mode, de la gastronomie et de la culture. Son paysage urbain du XIXe siècle est traversé par de larges boulevards et la Seine. + """ + When reranking request + Then reranking results are returned + Then reranking highest score is index 2 and lowest score is index 3 diff --git a/examples/server/tests/features/steps/steps.py b/examples/server/tests/features/steps/steps.py index 0fea0fe87..2611614ba 100644 --- a/examples/server/tests/features/steps/steps.py +++ b/examples/server/tests/features/steps/steps.py @@ -68,6 +68,7 @@ def step_server_config(context, server_fqdn: str, server_port: str): context.server_api_key = None context.server_continuous_batching = False context.server_embeddings = False + context.server_reranking = False context.server_metrics = False context.server_process = None context.seed = None @@ -83,6 +84,10 @@ def step_server_config(context, server_fqdn: str, server_port: str): context.concurrent_tasks = [] context.prompts = [] + context.reranking_query = None + context.reranking_documents = [] + context.reranking_results = None + @step('a model file {hf_file} from HF repo {hf_repo}') def step_download_hf_model(context, hf_file: str, hf_repo: str): @@ -172,10 +177,13 @@ def step_server_continuous_batching(context): context.server_continuous_batching = True -@step('embeddings extraction') +@step('enable embeddings endpoint') def step_server_embeddings(context): context.server_embeddings = True +@step('enable reranking endpoint') +def step_server_reranking(context): + context.server_reranking = True @step('prometheus compatible metrics exposed') def step_server_metrics(context): @@ -452,6 +460,14 @@ def step_impl(context, n_ga_w): def step_prompt_passkey(context): context.prompt_passkey = context_text(context) +@step('a rerank query') +def step_set_rerank_query(context): + context.reranking_query = context_text(context) + context.reranking_documents = [] + +@step('a rerank document') +def step_set_rerank_document(context): + context.reranking_documents.append(context_text(context)) @step('{n_prompts:d} fixed prompts') def step_fixed_prompts(context, n_prompts): @@ -619,6 +635,22 @@ async def step_compute_embedding(context): context.embeddings = await request_embedding(context_text(context), None, base_url=context.base_url) +@step('reranking request') +@async_run_until_complete +async def step_compute_reranking(context): + async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session: + async with session.post(f'{context.base_url}/reranking', + json={ + "query": context.reranking_query, + "documents": context.reranking_documents, + }) as response: + if response.status == 200: + response_json = await response.json() + context.reranking_results = response_json['results'] + else: + context.reranking_results = response.status + + @step('all embeddings are the same') @async_run_until_complete async def step_all_embeddings_are_the_same(context): @@ -704,6 +736,24 @@ async def all_embeddings_are_generated(context): for i in range(n_embedding_requests): assert_embeddings(context.tasks_result.pop().pop()) +@step('reranking results are returned') +def reranking_results_are_returned(context): + assert len(context.reranking_results) == len(context.reranking_documents) + +@step('reranking highest score is index {idx_high:d} and lowest score is index {idx_low:d}') +def reranking_results_are_returned(context, idx_high: int, idx_low: int): + max_score, max_idx = 0, 0 + min_score, min_idx = 0, 0 + for res in context.reranking_results: + if max_score < res['relevance_score']: + max_score = res['relevance_score'] + max_idx = res['index'] + if min_score > res['relevance_score']: + min_score = res['relevance_score'] + min_idx = res['index'] + print(context.reranking_results) + assert max_idx == idx_high + assert min_idx == idx_low @step('adding special tokens') def step_tokenize_set_add_special(context): @@ -1362,6 +1412,8 @@ def start_server_background(context): server_args.append('--cont-batching') if context.server_embeddings: server_args.append('--embedding') + if context.server_reranking: + server_args.append('--reranking') if context.server_metrics: server_args.append('--metrics') if context.model_alias: diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index f093f547f..47dfdfde5 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -537,7 +537,7 @@ static json format_embeddings_response_oaicompat(const json & request, const jso json res = json { {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))}, {"object", "list"}, - {"usage", json { + {"usage", json { // TODO: fill {"prompt_tokens", 0}, {"total_tokens", 0} }}, @@ -547,6 +547,29 @@ static json format_embeddings_response_oaicompat(const json & request, const jso return res; } +static json format_response_rerank(const json & request, const json & ranks) { + json data = json::array(); + int i = 0; + for (const auto & rank : ranks) { + data.push_back(json{ + {"index", i++}, + {"relevance_score", json_value(rank, "score", 0.0)}, + }); + } + + json res = json { + {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))}, + {"object", "list"}, + {"usage", json { // TODO: fill + {"prompt_tokens", 0}, + {"total_tokens", 0} + }}, + {"results", data} + }; + + return res; +} + static bool is_valid_utf8(const std::string & str) { const unsigned char* bytes = reinterpret_cast(str.data()); const unsigned char* end = bytes + str.length(); diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 2fd2e9d2b..ebe66a4a3 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -345,6 +345,8 @@ class MODEL_TENSOR(IntEnum): ENC_FFN_DOWN = auto() ENC_FFN_UP = auto() ENC_OUTPUT_NORM = auto() + CLS = auto() # classifier + CLS_OUT = auto() # classifier output projection MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { @@ -504,6 +506,8 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = { MODEL_TENSOR.ENC_FFN_DOWN: "enc.blk.{bid}.ffn_down", MODEL_TENSOR.ENC_FFN_UP: "enc.blk.{bid}.ffn_up", MODEL_TENSOR.ENC_OUTPUT_NORM: "enc.output_norm", + MODEL_TENSOR.CLS: "cls", + MODEL_TENSOR.CLS_OUT: "cls.output", } MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { @@ -613,6 +617,8 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_UP, MODEL_TENSOR.LAYER_OUT_NORM, + MODEL_TENSOR.CLS, + MODEL_TENSOR.CLS_OUT, ], MODEL_ARCH.NOMIC_BERT: [ MODEL_TENSOR.TOKEN_EMBD, @@ -644,6 +650,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_GATE, MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.LAYER_OUT_NORM, + MODEL_TENSOR.CLS, ], MODEL_ARCH.MPT: [ MODEL_TENSOR.TOKEN_EMBD, diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 5ef91f11d..e7e9b6fd5 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -679,6 +679,15 @@ class TensorNameMap: MODEL_TENSOR.ENC_OUTPUT_NORM: ( "encoder.final_layer_norm", # t5 ), + + MODEL_TENSOR.CLS: ( + "classifier", # jina + "classifier.dense", # roberta + ), + + MODEL_TENSOR.CLS_OUT: ( + "classifier.out_proj", # roberta + ), } # architecture-specific block mappings diff --git a/include/llama.h b/include/llama.h index 4ea8a2c2b..7cae1bbe2 100644 --- a/include/llama.h +++ b/include/llama.h @@ -193,6 +193,7 @@ extern "C" { LLAMA_POOLING_TYPE_MEAN = 1, LLAMA_POOLING_TYPE_CLS = 2, LLAMA_POOLING_TYPE_LAST = 3, + LLAMA_POOLING_TYPE_RANK = 4, // used by reranking models to attach the classification head to the graph }; enum llama_attention_type { @@ -202,9 +203,9 @@ extern "C" { }; enum llama_split_mode { - LLAMA_SPLIT_MODE_NONE = 0, // single GPU - LLAMA_SPLIT_MODE_LAYER = 1, // split layers and KV across GPUs - LLAMA_SPLIT_MODE_ROW = 2, // split rows across GPUs + LLAMA_SPLIT_MODE_NONE = 0, // single GPU + LLAMA_SPLIT_MODE_LAYER = 1, // split layers and KV across GPUs + LLAMA_SPLIT_MODE_ROW = 2, // split rows across GPUs }; // TODO: simplify (https://github.com/ggerganov/llama.cpp/pull/9294#pullrequestreview-2286561979) @@ -872,7 +873,8 @@ extern "C" { // Get the embeddings for a sequence id // Returns NULL if pooling_type is LLAMA_POOLING_TYPE_NONE - // shape: [n_embd] (1-dimensional) + // when pooling_type == LLAMA_POOLING_TYPE_RANK, returns float[1] with the rank of the sequence + // otherwise: float[n_embd] (1-dimensional) LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id); // diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index e4d844a73..d2f34ddd6 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -1554,7 +1554,7 @@ std::vector llama_tokenize_internal( } break; case LLAMA_VOCAB_TYPE_UGM: { - if (add_special && vocab.tokenizer_add_bos != 0) { + if (add_special && vocab.tokenizer_add_bos) { GGML_ASSERT(vocab.special_bos_id != -1); output.push_back(vocab.special_bos_id); } @@ -1572,14 +1572,14 @@ std::vector llama_tokenize_internal( } } - if (add_special && vocab.tokenizer_add_bos != 0 && output.size() >= 2 && output[1] == vocab.special_bos_id) { + if (add_special && vocab.tokenizer_add_bos && output.size() >= 2 && output[1] == vocab.special_bos_id) { LLAMA_LOG_WARN( "%s: Added a BOS token to the prompt as specified by the model but the prompt " "also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. " "Are you sure this is what you want?\n", __FUNCTION__); } - if (add_special && vocab.tokenizer_add_eos == 1) { + if (add_special && vocab.tokenizer_add_eos) { GGML_ASSERT(vocab.special_eos_id != -1); output.push_back(vocab.special_eos_id); } @@ -1791,11 +1791,13 @@ int32_t llama_token_to_piece_impl(const struct llama_vocab & vocab, llama_token // suppressing them like CONTROL tokens. if (attr & (attr_special | LLAMA_TOKEN_ATTR_USER_DEFINED)) { return _try_copy(token_text.data(), token_text.size()); - } else if (attr & LLAMA_TOKEN_ATTR_NORMAL) { + } + if (attr & LLAMA_TOKEN_ATTR_NORMAL) { std::string result = token_text; llama_unescape_whitespace(result); return _try_copy(result.data(), result.size()); - } else if (attr & LLAMA_TOKEN_ATTR_BYTE) { + } + if (attr & LLAMA_TOKEN_ATTR_BYTE) { char byte = (char) llama_token_to_byte(vocab, token); return _try_copy((char*) &byte, 1); } @@ -1806,7 +1808,8 @@ int32_t llama_token_to_piece_impl(const struct llama_vocab & vocab, llama_token // suppressing them like CONTROL tokens. if (attr & (attr_special | LLAMA_TOKEN_ATTR_USER_DEFINED)) { return _try_copy(token_text.data(), token_text.size()); - } else if (attr & LLAMA_TOKEN_ATTR_NORMAL) { + } + if (attr & LLAMA_TOKEN_ATTR_NORMAL) { std::string result = llama_decode_text(token_text); return _try_copy(result.data(), result.size()); } diff --git a/src/llama.cpp b/src/llama.cpp index 44afb31d7..c466cd88b 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -606,6 +606,8 @@ enum llm_tensor { LLM_TENSOR_ENC_FFN_DOWN, LLM_TENSOR_ENC_FFN_UP, LLM_TENSOR_ENC_OUTPUT_NORM, + LLM_TENSOR_CLS, + LLM_TENSOR_CLS_OUT, }; static const std::map> LLM_TENSOR_NAMES = { @@ -793,6 +795,8 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" }, { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_CLS, "cls" }, + { LLM_TENSOR_CLS_OUT, "cls.output" }, }, }, { @@ -828,6 +832,7 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_CLS, "cls" }, }, }, { @@ -2894,6 +2899,7 @@ struct llama_model { llama_hparams hparams = {}; llama_vocab vocab; + // TODO: should init all tensors to nullptr struct ggml_tensor * tok_embd; struct ggml_tensor * type_embd; struct ggml_tensor * pos_embd; @@ -2906,6 +2912,12 @@ struct llama_model { struct ggml_tensor * output_b; struct ggml_tensor * output_norm_enc; + // classifier + struct ggml_tensor * cls; + struct ggml_tensor * cls_b; + struct ggml_tensor * cls_out = nullptr; + struct ggml_tensor * cls_out_b = nullptr; + std::vector layers; llama_split_mode split_mode; @@ -5604,11 +5616,11 @@ static void llm_load_hparams( ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, hparams.n_vocab_type); - ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type); + ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); hparams.f_max_alibi_bias = 8.0f; switch (hparams.n_layer) { - case 4: model.type = e_model::MODEL_33M; break; // jina-embeddings-small + case 4: model.type = e_model::MODEL_33M; break; // jina-embeddings-small case 12: model.type = e_model::MODEL_137M; break; // jina-embeddings-base } } break; @@ -6313,6 +6325,7 @@ static void llm_load_vocab( tokenizer_pre == "phi-2" || tokenizer_pre == "jina-es" || tokenizer_pre == "jina-de" || + tokenizer_pre == "jina-v1-en" || tokenizer_pre == "jina-v2-es" || tokenizer_pre == "jina-v2-de" || tokenizer_pre == "jina-v2-code") { @@ -6439,7 +6452,12 @@ static void llm_load_vocab( for (uint32_t i = 0; i < n_vocab; i++) { std::string word = gguf_get_arr_str(ctx, token_idx, i); - GGML_ASSERT(unicode_cpts_from_utf8(word).size() > 0); + + //GGML_ASSERT(unicode_cpts_from_utf8(word).size() > 0); + if (word.empty()) { + LLAMA_LOG_WARN("%s: empty token at index %u\n", __func__, i); + word = "[EMPTY_" + std::to_string(i) + "]"; + } vocab.token_to_id[word] = i; vocab.max_token_len = std::max(vocab.max_token_len, (int) word.size()); @@ -6520,8 +6538,14 @@ static void llm_load_vocab( vocab.linefeed_id = ids[0]; } else { const std::vector ids = llama_tokenize_internal(vocab, "\xC4\x8A", false); // U+010A - GGML_ASSERT(!ids.empty() && "model vocab missing newline token"); - vocab.linefeed_id = ids[0]; + + //GGML_ASSERT(!ids.empty() && "model vocab missing newline token"); + if (ids.empty()) { + LLAMA_LOG_WARN("%s: model vocab missing newline token, using special_pad_id instead\n", __func__); + vocab.linefeed_id = vocab.special_pad_id; + } else { + vocab.linefeed_id = ids[0]; + } } // special tokens @@ -7394,6 +7418,12 @@ static bool llm_load_tensors( if (model.arch == LLM_ARCH_BERT) { model.pos_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}); + + model.cls = ml.create_tensor(ctx_output, tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + model.cls_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_CLS, "bias"), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + + model.cls_out = ml.create_tensor(ctx_output, tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, 1}, llama_model_loader::TENSOR_NOT_REQUIRED); + model.cls_out_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_CLS_OUT, "bias"), {1}, llama_model_loader::TENSOR_NOT_REQUIRED); } model.tok_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}); @@ -7446,6 +7476,8 @@ static bool llm_load_tensors( model.tok_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}); // LayerNorm model.tok_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}); //LayerNorm bias + model.cls = ml.create_tensor(ctx_output, tn(LLM_TENSOR_CLS, "weight"), {n_embd, 1}, llama_model_loader::TENSOR_NOT_REQUIRED); + model.cls_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_CLS, "bias"), {1}, llama_model_loader::TENSOR_NOT_REQUIRED); for (int i = 0; i < n_layer; ++i) { ggml_context * ctx_layer = ctx_for_layer(i); ggml_context * ctx_split = ctx_for_layer_split(i); @@ -10279,6 +10311,10 @@ struct llm_build_context { struct ggml_tensor * cur; switch (pooling_type) { + case LLAMA_POOLING_TYPE_NONE: + { + cur = inp; + } break; case LLAMA_POOLING_TYPE_MEAN: { struct ggml_tensor * inp_mean = build_inp_mean(); @@ -10290,9 +10326,26 @@ struct llm_build_context { struct ggml_tensor * inp_cls = build_inp_cls(); cur = ggml_get_rows(ctx0, inp, inp_cls); } break; - case LLAMA_POOLING_TYPE_NONE: + case LLAMA_POOLING_TYPE_RANK: { - cur = inp; + struct ggml_tensor * inp_cls = build_inp_cls(); + inp = ggml_get_rows(ctx0, inp, inp_cls); + + // classification head + // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566 + GGML_ASSERT(model.cls != nullptr); + GGML_ASSERT(model.cls_b != nullptr); + + cur = ggml_add (ctx0, ggml_mul_mat(ctx0, model.cls, inp), model.cls_b); + cur = ggml_tanh(ctx0, cur); + + // some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en + // https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896 + if (model.cls_out) { + GGML_ASSERT(model.cls_out_b != nullptr); + + cur = ggml_add (ctx0, ggml_mul_mat(ctx0, model.cls_out, cur), model.cls_out_b); + } } break; default: { @@ -11521,8 +11574,8 @@ struct llm_build_context { inpL = cur; } - // final output cur = inpL; + cb(cur, "result_embd", -1); ggml_build_forward_expand(gf, cur); @@ -16682,7 +16735,9 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { } } - if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_CLS) { + if (cparams.embeddings && ( + cparams.pooling_type == LLAMA_POOLING_TYPE_CLS || + cparams.pooling_type == LLAMA_POOLING_TYPE_RANK)) { const int64_t n_tokens = batch.n_tokens; const int64_t n_seq_tokens = batch.n_seq_tokens; const int64_t n_seqs = batch.n_seqs; @@ -16697,7 +16752,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { const llama_seq_id seq_id = batch.seq_id[s][0]; // TODO: adapt limits to n_seqs when batch.equal_seqs is true - GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS"); + GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS or RANK"); for (int i = 0; i < n_seq_tokens; ++i) { const llama_pos pos = batch.pos[s*n_seq_tokens + i]; @@ -17237,6 +17292,20 @@ static int llama_decode_internal( ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float)); } } break; + case LLAMA_POOLING_TYPE_RANK: + { + // extract the rerank score - a single float per sequence + auto & embd_seq_out = lctx.embd_seq; + + for (uint32_t s = 0; s < ubatch.n_seqs; ++s) { + const llama_seq_id seq_id = ubatch.seq_id[s][0]; + if (embd_seq_out.find(seq_id) != embd_seq_out.end()) { + continue; + } + embd_seq_out[seq_id].resize(1); + ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (seq_id)*sizeof(float), sizeof(float)); + } + } break; case LLAMA_POOLING_TYPE_UNSPECIFIED: { GGML_ABORT("unknown pooling type"); @@ -17443,6 +17512,13 @@ static int llama_encode_internal( ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float)); } } break; + case LLAMA_POOLING_TYPE_RANK: + { + // TODO: this likely should be the same logic as in llama_decoder_internal, but better to + // wait for an encoder model that requires this pooling type in order to test it + // https://github.com/ggerganov/llama.cpp/pull/9510 + GGML_ABORT("RANK pooling not implemented yet"); + } case LLAMA_POOLING_TYPE_UNSPECIFIED: { GGML_ABORT("unknown pooling type"); From 589b48d41efb0e95133b77c335f4fb9779af9bfb Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 29 Sep 2024 14:38:18 +0300 Subject: [PATCH 11/33] contrib : add Resources section (#9675) --- CONTRIBUTING.md | 5 +++++ README.md | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index a9e000e52..3d7c6f86c 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -27,3 +27,8 @@ ![matmul](media/matmul.png) +# Resources + +The Github issues, PRs and discussions contain a lot of information that can be useful to get familiar with the codebase. For convenience, some of the more important information is referenced from Github projects: + +https://github.com/ggerganov/llama.cpp/projects diff --git a/README.md b/README.md index a452a6d78..ecc2df8ca 100644 --- a/README.md +++ b/README.md @@ -443,7 +443,7 @@ To learn more how to measure perplexity using llama.cpp, [read this documentatio - Contributors can open PRs - Collaborators can push to branches in the `llama.cpp` repo and merge PRs into the `master` branch - Collaborators will be invited based on contributions -- Any help with managing issues and PRs is very appreciated! +- Any help with managing issues, PRs and projects is very appreciated! - See [good first issues](https://github.com/ggerganov/llama.cpp/issues?q=is%3Aissue+is%3Aopen+label%3A%22good+first+issue%22) for tasks suitable for first contributions - Read the [CONTRIBUTING.md](CONTRIBUTING.md) for more information - Make sure to read this: [Inference at the edge](https://github.com/ggerganov/llama.cpp/discussions/205) From f99d3f8367174f7aba73c07fd87de687d4a0ece1 Mon Sep 17 00:00:00 2001 From: nopperl <54780682+nopperl@users.noreply.github.com> Date: Sun, 29 Sep 2024 12:02:06 +0000 Subject: [PATCH 12/33] py : add model class for Chameleon conversion (#9683) --- convert_hf_to_gguf.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 96a8830e9..f3857d487 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -4162,7 +4162,8 @@ class GraniteMoeModel(GraniteModel): return super().modify_tensors(data_torch, name, bid) -@Model.register("ChameleonForCausalLM") +@Model.register("ChameleonForConditionalGeneration") +@Model.register("ChameleonForCausalLM") # obsolete class ChameleonModel(Model): model_arch = gguf.MODEL_ARCH.CHAMELEON From faac0bae265449fd988c57bf894018edc36fbe1e Mon Sep 17 00:00:00 2001 From: matiaslin <45382001+matiaslin@users.noreply.github.com> Date: Sun, 29 Sep 2024 05:25:00 -0700 Subject: [PATCH 13/33] common : ensure llama_batch size does not exceed max size (#9668) A crash was observed when the number of tokens added to a batch exceeds llama_batch size. An assertion in llama_batch_add was added to protect against llama_batch size overflow. --- common/common.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/common/common.cpp b/common/common.cpp index e2b8574bf..a0611f3d1 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1437,6 +1437,8 @@ void llama_batch_add( llama_pos pos, const std::vector & seq_ids, bool logits) { + GGML_ASSERT(batch.seq_id[batch.n_tokens] && "llama_batch size exceeded"); + batch.token [batch.n_tokens] = id; batch.pos [batch.n_tokens] = pos; batch.n_seq_id[batch.n_tokens] = seq_ids.size(); From 6084bfb261b03f812de2255b05b6b5bb8d1c7171 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 24 Sep 2024 13:23:59 +0300 Subject: [PATCH 14/33] ggml : fix GGML_MAX_N_THREADS + improve formatting (ggml/969) --- ggml/include/ggml.h | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 9f96e0c48..f46d4a8a6 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -229,14 +229,16 @@ #define GGML_MAX_PARAMS 2048 #define GGML_MAX_CONTEXTS 64 #define GGML_MAX_SRC 10 -#ifndef GGML_MAX_NAME -#define GGML_MAX_NAME 64 #define GGML_MAX_N_THREADS 512 - -#endif #define GGML_MAX_OP_PARAMS 64 + +#ifndef GGML_MAX_NAME +# define GGML_MAX_NAME 64 +#endif + #define GGML_DEFAULT_N_THREADS 4 #define GGML_DEFAULT_GRAPH_SIZE 2048 + #if UINTPTR_MAX == 0xFFFFFFFF #define GGML_MEM_ALIGN 4 #else @@ -259,21 +261,21 @@ #define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1)) #ifndef NDEBUG -#define GGML_UNREACHABLE() do { fprintf(stderr, "statement should be unreachable\n"); abort(); } while(0) +# define GGML_UNREACHABLE() do { fprintf(stderr, "statement should be unreachable\n"); abort(); } while(0) #elif defined(__GNUC__) -#define GGML_UNREACHABLE() __builtin_unreachable() +# define GGML_UNREACHABLE() __builtin_unreachable() #elif defined(_MSC_VER) -#define GGML_UNREACHABLE() __assume(0) +# define GGML_UNREACHABLE() __assume(0) #else -#define GGML_UNREACHABLE() ((void) 0) +# define GGML_UNREACHABLE() ((void) 0) #endif #ifdef __cplusplus -#define GGML_NORETURN [[noreturn]] +# define GGML_NORETURN [[noreturn]] #elif defined(_MSC_VER) -#define GGML_NORETURN __declspec(noreturn) +# define GGML_NORETURN __declspec(noreturn) #else -#define GGML_NORETURN _Noreturn +# define GGML_NORETURN _Noreturn #endif #define GGML_ABORT(...) ggml_abort(__FILE__, __LINE__, __VA_ARGS__) From 544f409b4bd8fc98a3e87820f0ac934e00402de7 Mon Sep 17 00:00:00 2001 From: Salvatore Mesoraca Date: Thu, 26 Sep 2024 08:59:42 +0200 Subject: [PATCH 15/33] vulkan : argsort barriers must be under uniform control flow (ggml/951) a return before a barrier (that happens only in some threads in a workgroup) leads to UB. While the old code actually works on some devices, it fails on some others (i.e. "smaller" GPUs). BTW, I think it would be better to set specialization constants when the graph is built, in that way the local workgroup could be sized appropriately. But it would take a lot of work. Signed-off-by: Salvatore Mesoraca --- ggml/src/vulkan-shaders/argsort.comp | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/ggml/src/vulkan-shaders/argsort.comp b/ggml/src/vulkan-shaders/argsort.comp index e55414b03..d4fa45b1e 100644 --- a/ggml/src/vulkan-shaders/argsort.comp +++ b/ggml/src/vulkan-shaders/argsort.comp @@ -29,20 +29,18 @@ void main() { const int col = int(gl_LocalInvocationID.x); const uint row = gl_WorkGroupID.y; - if (col >= p.ncols_pad) { - return; - } - const uint row_offset = row * p.ncols; // initialize indices - dst_row[col] = col; + if (col < p.ncols_pad) { + dst_row[col] = col; + } barrier(); for (uint k = 2; k <= p.ncols_pad; k *= 2) { for (uint j = k / 2; j > 0; j /= 2) { const uint ixj = col ^ j; - if (ixj > col) { + if (col < p.ncols_pad && ixj > col) { if ((col & k) == 0) { if (dst_row[col] >= p.ncols || (dst_row[ixj] < p.ncols && (p.order == ASC ? From 0de8b203f1d31cf5ee0d2a3560a0ad78d44f4d4c Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Fri, 27 Sep 2024 02:58:01 -0500 Subject: [PATCH 16/33] vulkan : fix build for GGML_VULKAN_RUN_TESTS, add TFLOPS to log (ggml/961) --- ggml/src/ggml-vulkan.cpp | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/ggml/src/ggml-vulkan.cpp b/ggml/src/ggml-vulkan.cpp index a877145e8..70b729154 100644 --- a/ggml/src/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan.cpp @@ -5013,6 +5013,8 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t } } + ggml_pipeline_allocate_descriptor_sets(ctx->device); + vk_buffer d_X = ggml_vk_create_buffer_check(ctx->device, sizeof(X_TYPE) * x_ne, vk::MemoryPropertyFlagBits::eDeviceLocal); vk_buffer d_Y = ggml_vk_create_buffer_check(ctx->device, sizeof(Y_TYPE) * y_ne, vk::MemoryPropertyFlagBits::eDeviceLocal); vk_buffer d_D = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne, vk::MemoryPropertyFlagBits::eDeviceLocal); @@ -5129,7 +5131,9 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t avg_err /= m * n; - std::cerr << "TEST " << shname << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time / num_it << "ms avg_err=" << avg_err << std::endl; + double tflops = 2.0*m*n*k*batch*num_it / (time / 1000.0) / (1000.0*1000.0*1000.0*1000.0); + + std::cerr << "TEST " << shname << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time / num_it << "ms " << tflops << " TFLOPS avg_err=" << avg_err << std::endl; if (avg_err > 0.1) { std::cerr << "m = " << first_err_m << " n = " << first_err_n << " b = " << first_err_b << std::endl; @@ -5251,12 +5255,14 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_ ggml_pipeline_request_descriptor_sets(ctx->device, p, 1); + ggml_pipeline_allocate_descriptor_sets(ctx->device); + ggml_vk_buffer_write(qx_buf, 0, qx, qx_sz); vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue); ggml_vk_ctx_begin(ctx->device, subctx); const std::vector pc = { 1, (uint32_t)ne, (uint32_t)ne, (uint32_t)ne, (uint32_t)ne }; - ggml_vk_dispatch_pipeline(ctx, subctx, p, { { qx_buf, 0, qx_sz }, { x_buf, 0, x_sz_f16 } }, pc.size() * sizeof(int), pc.data(), { (uint32_t)ne, 1, 1}); + ggml_vk_dispatch_pipeline(ctx, subctx, p, { vk_subbuffer{ qx_buf, 0, qx_sz }, vk_subbuffer{ x_buf, 0, x_sz_f16 } }, pc.size() * sizeof(int), pc.data(), { (uint32_t)ne, 1, 1}); ggml_vk_ctx_end(subctx); auto begin = std::chrono::high_resolution_clock::now(); @@ -5383,6 +5389,8 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, } } + ggml_pipeline_allocate_descriptor_sets(ctx->device); + ggml_vk_buffer_write(qx_buf, 0, qx, qx_sz); ggml_vk_buffer_write(y_buf, 0, y, y_sz); @@ -5450,7 +5458,9 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, avg_err /= m * n; - std::cerr << "TEST MMQ " << shname << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time_ms / num_it << "ms avg_err=" << avg_err << std::endl; + double tflops = 2.0*m*n*k*batch*num_it / (time_ms / 1000.0) / (1000.0*1000.0*1000.0*1000.0); + + std::cerr << "TEST MMQ " << shname << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time_ms / num_it << "ms " << tflops << " TFLOPS avg_err=" << avg_err << std::endl; if (avg_err > 0.01 || std::isnan(avg_err)) { std::cerr << "m = " << first_err_m << " n = " << first_err_n << " b = " << first_err_b << std::endl; @@ -5502,9 +5512,6 @@ static ggml_tensor_extra_gpu * ggml_vk_tensor_create_extra(ggml_tensor * tensor) static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) { #if defined(GGML_VULKAN_RUN_TESTS) - ctx->staging = ggml_vk_create_buffer_check(ctx->device, 100ul * 1024ul * 1024ul, - vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached, - vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent); ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_F32); ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q4_0); ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q4_1); From 641002fba8d2a0c0269027e23d2ef58e90546028 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Sun, 29 Sep 2024 11:50:17 -0500 Subject: [PATCH 17/33] vulkan : multithread pipeline creation (ggml/963) --- ggml/src/ggml-vulkan.cpp | 41 ++++++++++++++++++++++++++++++++++++---- 1 file changed, 37 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-vulkan.cpp b/ggml/src/ggml-vulkan.cpp index 70b729154..c677a2728 100644 --- a/ggml/src/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan.cpp @@ -20,6 +20,8 @@ #include #include #include +#include +#include #include "ggml-impl.h" #include "ggml-backend-impl.h" @@ -607,13 +609,16 @@ typedef void (*ggml_vk_func_t)(ggml_backend_vk_context * ctx, vk_context& subctx GGML_CALL static void ggml_backend_vk_free(ggml_backend_t backend); -static void ggml_vk_create_pipeline(vk_device& device, vk_pipeline& pipeline, const std::string& name, size_t spv_size, const void* spv_data, const std::string& entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array wg_denoms, std::vector&& specialization_constants, uint32_t align) { +// variables to track number of compiles in progress +static uint32_t compile_count = 0; +static std::mutex compile_count_mutex; +static std::condition_variable compile_count_cond; + +static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipeline, const std::string name, size_t spv_size, const void* spv_data, const std::string entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array wg_denoms, std::vector specialization_constants, uint32_t align) { VK_LOG_DEBUG("ggml_vk_create_pipeline(" << device->name << ", " << name << ", " << entrypoint << ", " << parameter_count << ", " << push_constant_size << ", (" << wg_denoms[0] << "," << wg_denoms[1] << "," << wg_denoms[2] << "), specialization_constants, " << align << ")"); GGML_ASSERT(parameter_count > 0); GGML_ASSERT(wg_denoms[0] > 0 && wg_denoms[1] > 0 && wg_denoms[2] > 0); // NOLINT - std::lock_guard guard(device->mutex); - pipeline = std::make_shared(); pipeline->name = name; pipeline->parameter_count = parameter_count; @@ -681,7 +686,17 @@ static void ggml_vk_create_pipeline(vk_device& device, vk_pipeline& pipeline, co pipeline->layout); pipeline->pipeline = device->device.createComputePipeline(VK_NULL_HANDLE, compute_pipeline_create_info).value; - device->pipelines.insert({ pipeline->name, pipeline }); + { + std::lock_guard guard(device->mutex); + device->pipelines.insert({ pipeline->name, pipeline }); + } + + { + std::lock_guard guard(compile_count_mutex); + assert(compile_count > 0); + compile_count--; + } + compile_count_cond.notify_all(); } static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline) { @@ -1194,6 +1209,20 @@ static void ggml_vk_load_shaders(vk_device& device) { device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K] = std::make_shared(); device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL] = std::make_shared(); + std::vector> compiles; + auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const std::string &entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array wg_denoms, std::vector&& specialization_constants, uint32_t align) { + { + // wait until fewer than N compiles are in progress + uint32_t N = std::max(1u, std::thread::hardware_concurrency()); + std::unique_lock guard(compile_count_mutex); + while (compile_count >= N) { + compile_count_cond.wait(guard); + } + compile_count++; + } + compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), name, spv_size, spv_data, entrypoint, parameter_count, push_constant_size, wg_denoms, specialization_constants, align)); + }; + if (device->fp16) { ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->l, "matmul_f32_l", matmul_f32_f32_len, matmul_f32_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1); ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->m, "matmul_f32_m", matmul_f32_f32_len, matmul_f32_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1); @@ -1743,6 +1772,10 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_len, im2col_f32_f16_data, "main", 2, sizeof(vk_op_im2col_push_constants), {256, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1); + + for (auto &c : compiles) { + c.wait(); + } } static vk_device ggml_vk_get_device(size_t idx) { From aaa40999251f5d18309b3fddc5a7d576f5fdb4e1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Sun, 29 Sep 2024 19:56:17 +0200 Subject: [PATCH 18/33] CUDA: remove bad assert (ggml/972) --- ggml/src/ggml-cuda/im2col.cu | 1 - 1 file changed, 1 deletion(-) diff --git a/ggml/src/ggml-cuda/im2col.cu b/ggml/src/ggml-cuda/im2col.cu index 3d0d8d4e6..16463ab0f 100644 --- a/ggml/src/ggml-cuda/im2col.cu +++ b/ggml/src/ggml-cuda/im2col.cu @@ -69,7 +69,6 @@ void ggml_cuda_op_im2col(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { float * dst_d = (float *)dst->data; cudaStream_t stream = ctx.stream(); - GGML_ASSERT(src0->type == GGML_TYPE_F16); GGML_ASSERT(src1->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32); From d0b1d663e430354ab35853a6e1bce51cc8819376 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 29 Sep 2024 21:16:07 +0300 Subject: [PATCH 19/33] sync : ggml --- scripts/sync-ggml.last | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index 36eeed0cc..aa301462a 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -336c10a4c3c8ec99af484b25a0cddd397a09cdb2 +9a24b8c8c40eab7262d067e91d08df160678df8d From c919d5db39c8a7fcb64737f008e4b105ee0acd20 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 29 Sep 2024 21:18:23 +0300 Subject: [PATCH 20/33] ggml : define missing HWCAP flags (#9684) ggml-ci Co-authored-by: Willy Tarreau --- ggml/src/ggml.c | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index fac4466e3..81b651c6a 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -3687,6 +3687,10 @@ static inline int ggml_up(int n, int m) { #include #endif +#if !defined(HWCAP2_I8MM) +#define HWCAP2_I8MM 0 +#endif + static void ggml_init_arm_arch_features(void) { #if defined(__linux__) && defined(__aarch64__) uint32_t hwcap = getauxval(AT_HWCAP); From 8277a817f18967581b02b2248989d773e8e99998 Mon Sep 17 00:00:00 2001 From: Ruchira Hasaranga Date: Mon, 30 Sep 2024 13:53:42 +0530 Subject: [PATCH 21/33] console : utf-8 fix for windows stdin (#9690) * utf-8 fix for windows stdin * Update common/console.cpp --------- Co-authored-by: Georgi Gerganov --- common/console.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/common/console.cpp b/common/console.cpp index f65cbc6ed..078a8d678 100644 --- a/common/console.cpp +++ b/common/console.cpp @@ -94,6 +94,9 @@ namespace console { simple_io = true; } } + if (simple_io) { + _setmode(_fileno(stdin), _O_U8TEXT); + } #else // POSIX-specific console initialization if (!simple_io) { From ace4f4be37abed4801fbd54a94cf38a7ae462416 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 30 Sep 2024 17:48:49 +0300 Subject: [PATCH 22/33] flake.lock: Update (#9680) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Flake lock file updates: • Updated input 'nixpkgs': 'github:NixOS/nixpkgs/c04d5652cfa9742b1d519688f65d1bbccea9eb7e?narHash=sha256-PmUr/2GQGvFTIJ6/Tvsins7Q43KTMvMFhvG6oaYK%2BWk%3D' (2024-09-19) → 'github:NixOS/nixpkgs/1925c603f17fc89f4c8f6bf6f631a802ad85d784?narHash=sha256-J%2BPeFKSDV%2BpHL7ukkfpVzCOO7mBSrrpJ3svwBFABbhI%3D' (2024-09-26) Co-authored-by: github-actions[bot] --- flake.lock | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/flake.lock b/flake.lock index 6333a09f0..dde1ab527 100644 --- a/flake.lock +++ b/flake.lock @@ -20,11 +20,11 @@ }, "nixpkgs": { "locked": { - "lastModified": 1726755586, - "narHash": "sha256-PmUr/2GQGvFTIJ6/Tvsins7Q43KTMvMFhvG6oaYK+Wk=", + "lastModified": 1727348695, + "narHash": "sha256-J+PeFKSDV+pHL7ukkfpVzCOO7mBSrrpJ3svwBFABbhI=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "c04d5652cfa9742b1d519688f65d1bbccea9eb7e", + "rev": "1925c603f17fc89f4c8f6bf6f631a802ad85d784", "type": "github" }, "original": { From 08a43d05b6ba74de97610ae519450ad9996475e0 Mon Sep 17 00:00:00 2001 From: vb Date: Mon, 30 Sep 2024 17:03:47 +0200 Subject: [PATCH 23/33] py : update transfomers version (#9694) * update transfomers version. * update hfh version. --- examples/server/tests/requirements.txt | 2 +- requirements/requirements-convert_legacy_llama.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/server/tests/requirements.txt b/examples/server/tests/requirements.txt index f2d7e5c57..553954872 100644 --- a/examples/server/tests/requirements.txt +++ b/examples/server/tests/requirements.txt @@ -1,6 +1,6 @@ aiohttp~=3.9.3 behave~=1.2.6 -huggingface_hub~=0.20.3 +huggingface_hub~=0.23.2 numpy~=1.26.4 openai~=1.30.3 prometheus-client~=0.20.0 diff --git a/requirements/requirements-convert_legacy_llama.txt b/requirements/requirements-convert_legacy_llama.txt index 1d07b0952..859204b27 100644 --- a/requirements/requirements-convert_legacy_llama.txt +++ b/requirements/requirements-convert_legacy_llama.txt @@ -1,5 +1,5 @@ numpy~=1.26.4 sentencepiece~=0.2.0 -transformers>=4.40.1,<5.0.0 +transformers>=4.45.1,<5.0.0 gguf>=0.1.0 protobuf>=4.21.0,<5.0.0 From 511636df0c90826b4dd1fc21ff260c19d69a3b5d Mon Sep 17 00:00:00 2001 From: compilade Date: Mon, 30 Sep 2024 14:13:16 -0400 Subject: [PATCH 24/33] ci : reduce severity of unused Pyright ignore comments (#9697) --- .github/workflows/python-type-check.yml | 4 +++- examples/llava/convert_image_encoder_to_gguf.py | 4 ++-- pyrightconfig.json | 3 ++- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/.github/workflows/python-type-check.yml b/.github/workflows/python-type-check.yml index e5ff5e6d7..373bb6010 100644 --- a/.github/workflows/python-type-check.yml +++ b/.github/workflows/python-type-check.yml @@ -4,11 +4,13 @@ on: push: paths: - '.github/workflows/python-type-check.yml' + - 'pyrightconfig.json' - '**.py' - '**/requirements*.txt' pull_request: paths: - '.github/workflows/python-type-check.yml' + - 'pyrightconfig.json' - '**.py' - '**/requirements*.txt' @@ -33,6 +35,6 @@ jobs: - name: Type-check with Pyright uses: jakebailey/pyright-action@v2 with: - version: 1.1.370 + version: 1.1.382 level: warning warnings: true diff --git a/examples/llava/convert_image_encoder_to_gguf.py b/examples/llava/convert_image_encoder_to_gguf.py index 36f6b92fb..4fa1d6cea 100644 --- a/examples/llava/convert_image_encoder_to_gguf.py +++ b/examples/llava/convert_image_encoder_to_gguf.py @@ -274,7 +274,7 @@ fout.add_bool("clip.use_gelu", use_gelu) if has_llava_projector: - model.vision_model.encoder.layers.pop(-1) # pyright: ignore[reportAttributeAccessIssue] + model.vision_model.encoder.layers.pop(-1) projector = torch.load(args.llava_projector) for name, data in projector.items(): name = get_tensor_name(name) @@ -288,7 +288,7 @@ if has_llava_projector: print("Projector tensors added\n") -state_dict = model.state_dict() # pyright: ignore[reportAttributeAccessIssue] +state_dict = model.state_dict() for name, data in state_dict.items(): if should_skip_tensor(name, has_text_encoder, has_vision_encoder, has_llava_projector): # we don't need this diff --git a/pyrightconfig.json b/pyrightconfig.json index 6016f4b6d..9acbbeb78 100644 --- a/pyrightconfig.json +++ b/pyrightconfig.json @@ -5,7 +5,8 @@ "reportUnusedImport": "warning", "reportDuplicateImport": "error", "reportDeprecated": "warning", - "reportUnnecessaryTypeIgnoreComment": "warning", + "reportUnnecessaryTypeIgnoreComment": "information", + "disableBytesTypePromotions": false, // TODO: change once Python 3.12 is the minimum "executionEnvironments": [ { // TODO: make this version override work correctly From 6f1d9d71f4c568778a7637ff6582e6f6ba5fb9d3 Mon Sep 17 00:00:00 2001 From: serhii-nakon <57632032+serhii-nakon@users.noreply.github.com> Date: Mon, 30 Sep 2024 21:57:12 +0300 Subject: [PATCH 25/33] Fix Docker ROCM builds, use AMDGPU_TARGETS instead of GPU_TARGETS (#9641) * Fix Docker ROCM builds, use AMDGPU_TARGETS instead of GPU_TARGETS * Set ROCM_DOCKER_ARCH as string due it incorrectly build and cause OOM exit code --- .devops/full-rocm.Dockerfile | 6 +++--- .devops/llama-cli-rocm.Dockerfile | 6 +++--- .devops/llama-server-rocm.Dockerfile | 6 +++--- .github/workflows/build.yml | 2 +- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/.devops/full-rocm.Dockerfile b/.devops/full-rocm.Dockerfile index 680d1cb92..df496bcd2 100644 --- a/.devops/full-rocm.Dockerfile +++ b/.devops/full-rocm.Dockerfile @@ -11,7 +11,7 @@ FROM ${BASE_ROCM_DEV_CONTAINER} AS build # Unless otherwise specified, we make a fat build. # List from https://github.com/ggerganov/llama.cpp/pull/1087#issuecomment-1682807878 # This is mostly tied to rocBLAS supported archs. -ARG ROCM_DOCKER_ARCH=\ +ARG ROCM_DOCKER_ARCH="\ gfx803 \ gfx900 \ gfx906 \ @@ -21,7 +21,7 @@ ARG ROCM_DOCKER_ARCH=\ gfx1030 \ gfx1100 \ gfx1101 \ - gfx1102 + gfx1102" COPY requirements.txt requirements.txt COPY requirements requirements @@ -34,7 +34,7 @@ WORKDIR /app COPY . . # Set nvcc architecture -ENV GPU_TARGETS=${ROCM_DOCKER_ARCH} +ENV AMDGPU_TARGETS=${ROCM_DOCKER_ARCH} # Enable ROCm ENV GGML_HIPBLAS=1 ENV CC=/opt/rocm/llvm/bin/clang diff --git a/.devops/llama-cli-rocm.Dockerfile b/.devops/llama-cli-rocm.Dockerfile index c3d1ab067..e60c747bd 100644 --- a/.devops/llama-cli-rocm.Dockerfile +++ b/.devops/llama-cli-rocm.Dockerfile @@ -11,7 +11,7 @@ FROM ${BASE_ROCM_DEV_CONTAINER} AS build # Unless otherwise specified, we make a fat build. # List from https://github.com/ggerganov/llama.cpp/pull/1087#issuecomment-1682807878 # This is mostly tied to rocBLAS supported archs. -ARG ROCM_DOCKER_ARCH=\ +ARG ROCM_DOCKER_ARCH="\ gfx803 \ gfx900 \ gfx906 \ @@ -21,7 +21,7 @@ ARG ROCM_DOCKER_ARCH=\ gfx1030 \ gfx1100 \ gfx1101 \ - gfx1102 + gfx1102" COPY requirements.txt requirements.txt COPY requirements requirements @@ -34,7 +34,7 @@ WORKDIR /app COPY . . # Set nvcc architecture -ENV GPU_TARGETS=${ROCM_DOCKER_ARCH} +ENV AMDGPU_TARGETS=${ROCM_DOCKER_ARCH} # Enable ROCm ENV GGML_HIPBLAS=1 ENV CC=/opt/rocm/llvm/bin/clang diff --git a/.devops/llama-server-rocm.Dockerfile b/.devops/llama-server-rocm.Dockerfile index fd0e19ad6..8553af75b 100644 --- a/.devops/llama-server-rocm.Dockerfile +++ b/.devops/llama-server-rocm.Dockerfile @@ -11,7 +11,7 @@ FROM ${BASE_ROCM_DEV_CONTAINER} AS build # Unless otherwise specified, we make a fat build. # List from https://github.com/ggerganov/llama.cpp/pull/1087#issuecomment-1682807878 # This is mostly tied to rocBLAS supported archs. -ARG ROCM_DOCKER_ARCH=\ +ARG ROCM_DOCKER_ARCH="\ gfx803 \ gfx900 \ gfx906 \ @@ -21,7 +21,7 @@ ARG ROCM_DOCKER_ARCH=\ gfx1030 \ gfx1100 \ gfx1101 \ - gfx1102 + gfx1102" COPY requirements.txt requirements.txt COPY requirements requirements @@ -34,7 +34,7 @@ WORKDIR /app COPY . . # Set nvcc architecture -ENV GPU_TARGETS=${ROCM_DOCKER_ARCH} +ENV AMDGPU_TARGETS=${ROCM_DOCKER_ARCH} # Enable ROCm ENV GGML_HIPBLAS=1 ENV CC=/opt/rocm/llvm/bin/clang diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index e6a977b60..c71d422e7 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -1032,7 +1032,7 @@ jobs: run: | $env:HIP_PATH=$(Resolve-Path 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' | split-path | split-path) $env:CMAKE_PREFIX_PATH="${env:HIP_PATH}" - cmake -G "Unix Makefiles" -B build -S . -DCMAKE_C_COMPILER="${env:HIP_PATH}\bin\clang.exe" -DCMAKE_CXX_COMPILER="${env:HIP_PATH}\bin\clang++.exe" -DGGML_HIPBLAS=ON -DCMAKE_BUILD_TYPE=Release -DGPU_TARGETS=${{ matrix.gpu_target }} -DGGML_RPC=ON + cmake -G "Unix Makefiles" -B build -S . -DCMAKE_C_COMPILER="${env:HIP_PATH}\bin\clang.exe" -DCMAKE_CXX_COMPILER="${env:HIP_PATH}\bin\clang++.exe" -DGGML_HIPBLAS=ON -DCMAKE_BUILD_TYPE=Release -DAMDGPU_TARGETS=${{ matrix.gpu_target }} -DGGML_RPC=ON cmake --build build -j ${env:NUMBER_OF_PROCESSORS} md "build\bin\rocblas\library\" cp "${env:HIP_PATH}\bin\hipblas.dll" "build\bin\" From 1927378bcce20ba72b6c89d5977b854a4bcaeb5d Mon Sep 17 00:00:00 2001 From: compilade Date: Tue, 1 Oct 2024 02:31:36 -0400 Subject: [PATCH 26/33] convert : refactor rope_freqs generation (#9396) * convert : refactor rope_freqs generation This should also fix vocab-only conversion for Phi-3. * convert : adapt MiniCPM3 to separate rope_freqs insertion MiniCPM3's tokenizer is treated as a SentencePiece tokenizer to avoid having to run its custom Python code which mixes tokenization in the same file as tool calls. gguf-py : add long and short RoPE factors to tensor mappings Empty, but the key names are used to populate the mappings. --- convert_hf_to_gguf.py | 61 ++++++++++++++++++---------------- convert_lora_to_gguf.py | 5 ++- gguf-py/gguf/constants.py | 4 +++ gguf-py/gguf/tensor_mapping.py | 3 ++ 4 files changed, 44 insertions(+), 29 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index f3857d487..da5feb25b 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -15,6 +15,7 @@ from enum import IntEnum from pathlib import Path from hashlib import sha256 from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Literal, Sequence, TypeVar, cast +from itertools import chain import math import numpy as np @@ -64,7 +65,6 @@ class Model: model_name: str | None metadata_override: Path | None dir_model_card: Path - is_lora: bool # subclasses should define this! model_arch: gguf.MODEL_ARCH @@ -72,7 +72,7 @@ class Model: def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool = False, use_temp_file: bool = False, eager: bool = False, metadata_override: Path | None = None, model_name: str | None = None, - split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False, is_lora: bool = False): + split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False): if type(self) is Model: raise TypeError(f"{type(self).__name__!r} should not be directly instantiated") @@ -94,7 +94,6 @@ class Model: self.metadata_override = metadata_override self.model_name = model_name self.dir_model_card = dir_model # overridden in convert_lora_to_gguf.py - self.is_lora = is_lora # true if model is used inside convert_lora_to_gguf.py # Apply heuristics to figure out typical tensor encoding based on first layer tensor encoding type if self.ftype == gguf.LlamaFileType.GUESSED: @@ -270,10 +269,14 @@ class Model: return False + # some models need extra generated tensors (like rope_freqs) + def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: + return () + def prepare_tensors(self): max_name_len = max(len(s) for _, s in self.tensor_map.mapping.values()) + len(".weight,") - for name, data_torch in self.get_tensors(): + for name, data_torch in chain(self.generate_extra_tensors(), self.get_tensors()): # we don't need these if name.endswith((".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq")): continue @@ -1617,7 +1620,7 @@ class LlamaModel(Model): return [(self.map_tensor_name(name), data_torch)] - def prepare_tensors(self): + def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: if rope_scaling := self.find_hparam(["rope_scaling"], optional=True): if rope_scaling.get("rope_type", '').lower() == "llama3": base = self.hparams.get("rope_theta", 10000.0) @@ -1644,9 +1647,9 @@ class LlamaModel(Model): smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor) rope_factors.append(1 / ((1 - smooth) / factor + smooth)) - if not self.is_lora: - self.gguf_writer.add_tensor(self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), np.array(rope_factors, dtype=np.float32)) + yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), torch.tensor(rope_factors, dtype=torch.float32)) + def prepare_tensors(self): super().prepare_tensors() if self._experts is not None: @@ -1870,8 +1873,6 @@ class MiniCPM3Model(Model): def set_gguf_parameters(self): hparams = self.hparams - rope_dims = hparams["qk_rope_head_dim"] - self.gguf_writer.add_file_type(self.ftype) self.gguf_writer.add_context_length(hparams["max_position_embeddings"]) self.gguf_writer.add_embedding_length(hparams["hidden_size"]) @@ -1887,24 +1888,25 @@ class MiniCPM3Model(Model): self.gguf_writer.add_key_length(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"]) self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"]) + def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: rope_scaling = self.find_hparam(['rope_scaling'], True) - if rope_scaling is None: - return + if rope_scaling is not None: + rope_dims = self.hparams["qk_rope_head_dim"] - long_factors = rope_scaling.get('long_factor', None) - short_factors = rope_scaling.get('short_factor', None) + long_factors = rope_scaling.get('long_factor', None) + short_factors = rope_scaling.get('short_factor', None) - if long_factors is None or short_factors is None: - raise KeyError('Missing the required key rope_scaling.long_factor or rope_scaling_short_factor') + if long_factors is None or short_factors is None: + raise KeyError('Missing the required key rope_scaling.long_factor or rope_scaling_short_factor') - if len(long_factors) != len(short_factors) or len(long_factors) != rope_dims / 2: - raise ValueError(f'The length of rope long and short factors must be {rope_dims / 2}') + if len(long_factors) != len(short_factors) or len(long_factors) != rope_dims / 2: + raise ValueError(f'The length of rope long and short factors must be {rope_dims / 2}') - self.gguf_writer.add_tensor(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ROPE_FACTORS_LONG] + ".weight", np.array(long_factors, dtype=np.float32)) - self.gguf_writer.add_tensor(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT] + ".weight", np.array(short_factors, dtype=np.float32)) + yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_LONG), torch.tensor(long_factors, dtype=torch.float32)) + yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT), torch.tensor(short_factors, dtype=torch.float32)) def set_vocab(self): - self._set_vocab_llama_hf() + self._set_vocab_sentencepiece() def _reverse_hf_permute(self, weights: Tensor, n_head: int, n_kv_head: int | None = None) -> Tensor: if n_kv_head is not None and n_head != n_kv_head: @@ -2216,6 +2218,13 @@ class Phi3MiniModel(Model): self.gguf_writer.add_file_type(self.ftype) self.gguf_writer.add_sliding_window(self.find_hparam(["sliding_window"])) + def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: + n_embd = self.find_hparam(["hidden_size", "n_embd"]) + n_head = self.find_hparam(["num_attention_heads", "n_head"]) + max_pos_embds = self.find_hparam(["n_positions", "max_position_embeddings"]) + orig_max_pos_embds = self.find_hparam(["original_max_position_embeddings"]) + rope_dims = n_embd // n_head + # write rope scaling for long context (128k) model rope_scaling = self.find_hparam(['rope_scaling'], True) if rope_scaling is None: @@ -2245,9 +2254,8 @@ class Phi3MiniModel(Model): if len(long_factors) != len(short_factors) or len(long_factors) != rope_dims / 2: raise ValueError(f'The length of rope long and short factors must be {rope_dims / 2}') - if not self.is_lora: - self.gguf_writer.add_tensor(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ROPE_FACTORS_LONG] + ".weight", np.array(long_factors, dtype=np.float32)) - self.gguf_writer.add_tensor(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT] + ".weight", np.array(short_factors, dtype=np.float32)) + yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_LONG), torch.tensor(long_factors, dtype=torch.float32)) + yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT), torch.tensor(short_factors, dtype=torch.float32)) @Model.register("PlamoForCausalLM") @@ -4071,7 +4079,7 @@ class ExaoneModel(Model): self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) self.gguf_writer.add_rope_scaling_factor(hparams["rope_scaling"]["factor"]) - def prepare_tensors(self): + def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: if rope_scaling := self.find_hparam(["rope_scaling"], optional=True): if rope_scaling.get("rope_type", '').lower() == "llama3": base = self.hparams.get("rope_theta", 10000.0) @@ -4098,10 +4106,7 @@ class ExaoneModel(Model): smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor) rope_factors.append(1 / ((1 - smooth) / factor + smooth)) - if not self.is_lora: - self.gguf_writer.add_tensor(self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), np.array(rope_factors, dtype=np.float32)) - - super().prepare_tensors() + yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), torch.tensor(rope_factors, dtype=torch.float32)) @Model.register("GraniteForCausalLM") diff --git a/convert_lora_to_gguf.py b/convert_lora_to_gguf.py index d1c94e580..439a78de1 100755 --- a/convert_lora_to_gguf.py +++ b/convert_lora_to_gguf.py @@ -331,6 +331,10 @@ if __name__ == '__main__': self.gguf_writer.add_float32(gguf.Keys.Adapter.LORA_ALPHA, self.lora_alpha) super().set_gguf_parameters() + def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: + # Never add extra tensors (e.g. rope_freqs) for LoRA adapters + return () + def get_tensors(self) -> Iterator[tuple[str, Tensor]]: tensor_map: dict[str, PartialLoraTensor] = {} @@ -392,7 +396,6 @@ if __name__ == '__main__': dry_run=args.dry_run, dir_lora_model=dir_lora, lora_alpha=alpha, - is_lora=True, ) logger.info("Exporting model...") diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index ebe66a4a3..e08617ba2 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -814,6 +814,8 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FACTORS_LONG, + MODEL_TENSOR.ROPE_FACTORS_SHORT, MODEL_TENSOR.ATTN_NORM, MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, @@ -892,6 +894,8 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FACTORS_LONG, + MODEL_TENSOR.ROPE_FACTORS_SHORT, MODEL_TENSOR.ATTN_NORM, MODEL_TENSOR.ATTN_Q_A, MODEL_TENSOR.ATTN_Q_B, diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index e7e9b6fd5..f4a787c56 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -87,6 +87,9 @@ class TensorNameMap: "rope.freqs", # llama-pth "rotary_pos_emb.inv_freq", # chatglm ), + + MODEL_TENSOR.ROPE_FACTORS_LONG: (), + MODEL_TENSOR.ROPE_FACTORS_SHORT: (), } block_mappings_cfg: dict[MODEL_TENSOR, tuple[str, ...]] = { From a90484c6d9db699bf739d0f33daf1c50cbdd45c9 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 1 Oct 2024 11:42:01 +0300 Subject: [PATCH 27/33] llama : print correct model type for Llama 3.2 1B and 3B --- src/llama.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/llama.cpp b/src/llama.cpp index c466cd88b..d1d27d21e 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -5502,8 +5502,10 @@ static void llm_load_hparams( } } else { switch (hparams.n_layer) { + case 16: model.type = e_model::MODEL_1B; break; // Llama 3.2 1B case 22: model.type = e_model::MODEL_1B; break; case 26: model.type = e_model::MODEL_3B; break; + case 28: model.type = e_model::MODEL_3B; break; // Llama 3.2 3B // granite uses a vocab with len 49152 case 32: model.type = hparams.n_vocab == 49152 ? e_model::MODEL_3B : (hparams.n_vocab < 40000 ? e_model::MODEL_7B : e_model::MODEL_8B); break; case 36: model.type = e_model::MODEL_8B; break; // granite From cad341d88924788b940997c55e7bef000c99e784 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 1 Oct 2024 16:00:25 +0300 Subject: [PATCH 28/33] metal : reduce command encoding overhead (#9698) * metal : reduce command encoding overhead ggml-ci * metal : add comments --- examples/cvector-generator/pca.hpp | 7 - examples/llava/clip.cpp | 6 - ggml/include/ggml-metal.h | 5 - ggml/src/ggml-metal.m | 3888 ++++++++++++++-------------- src/llama.cpp | 6 - 5 files changed, 2000 insertions(+), 1912 deletions(-) diff --git a/examples/cvector-generator/pca.hpp b/examples/cvector-generator/pca.hpp index a969c486d..f6e307fbc 100644 --- a/examples/cvector-generator/pca.hpp +++ b/examples/cvector-generator/pca.hpp @@ -204,13 +204,6 @@ static ggml_status compute_piter( ggml_backend_cpu_set_n_threads(model.backend, params.n_threads); } -// TODO: enable GPU support when support for GGML_OP_SQRT is added -//#ifdef GGML_USE_METAL -// if (ggml_backend_is_metal(model.backend)) { -// ggml_backend_metal_set_n_cb(model.backend, params.n_threads); -// } -//#endif - ggml_status res = ggml_backend_graph_compute(model.backend, gf); if (res == GGML_STATUS_SUCCESS) { auto extract_i = [](std::string prefix, std::string str) -> int { diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp index 8aa7b0750..14e02c8dd 100644 --- a/examples/llava/clip.cpp +++ b/examples/llava/clip.cpp @@ -2444,12 +2444,6 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima ggml_backend_cpu_set_n_threads(ctx->backend, n_threads); } -#ifdef GGML_USE_METAL - if (ggml_backend_is_metal(ctx->backend)) { - ggml_backend_metal_set_n_cb(ctx->backend, n_threads); - } -#endif - ggml_backend_graph_compute(ctx->backend, gf); // the last node is the embedding tensor diff --git a/ggml/include/ggml-metal.h b/ggml/include/ggml-metal.h index d483cf1ac..4d4165324 100644 --- a/ggml/include/ggml-metal.h +++ b/ggml/include/ggml-metal.h @@ -25,9 +25,6 @@ #include #include -// max memory buffers that can be mapped to the device -#define GGML_METAL_MAX_BUFFERS 64 - struct ggml_tensor; struct ggml_cgraph; @@ -48,8 +45,6 @@ GGML_API bool ggml_backend_is_metal(ggml_backend_t backend); GGML_API GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size); -GGML_API void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb); - GGML_API void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data); GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void); diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index ef3b7f0e8..9da08fe2e 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -12,6 +12,12 @@ #define MIN(a, b) ((a) < (b) ? (a) : (b)) #define MAX(a, b) ((a) > (b) ? (a) : (b)) +// max memory buffers that can be mapped to the device +#define GGML_METAL_MAX_BUFFERS 64 + +// max number of MTLCommandBuffer used to submit a graph for processing +#define GGML_METAL_MAX_COMMAND_BUFFERS 8 + #ifdef GGML_METAL_NDEBUG #define GGML_METAL_LOG(...) #define GGML_METAL_LOG_INFO(...) @@ -221,11 +227,11 @@ enum ggml_metal_kernel_type { }; struct ggml_backend_metal_context { - int n_cb; - id device; id queue; + MTLComputePassDescriptor * edesc; + dispatch_queue_t d_queue; struct ggml_metal_kernel kernels[GGML_METAL_KERNEL_TYPE_COUNT]; @@ -233,7 +239,27 @@ struct ggml_backend_metal_context { bool support_simdgroup_reduction; bool support_simdgroup_mm; - bool should_capture_next_compute; + // capture state + bool capture_next_compute; + bool capture_started; + + id capture_scope; + + // command buffer state + int n_cb; // number of extra threads used to submit the command buffers + int n_nodes_0; // number of nodes submitted by the main thread + int n_nodes_1; // remaining number of nodes submitted by the n_cb threads + int n_nodes_per_cb; + + struct ggml_cgraph * gf; + + // the callback given to the thread pool + // TODO: ideally, this should be created once, utilizing the command buffer state above + // for some reason, doing it like this leads to a crash + void (^encode_async)(size_t ith); + + // n_cb command buffers + 1 used by the main thread + id command_buffers[GGML_METAL_MAX_COMMAND_BUFFERS + 1]; // abort ggml_metal_graph_compute if callback returns true ggml_abort_callback abort_callback; @@ -303,7 +329,7 @@ static void * ggml_metal_host_malloc(size_t n) { return data; } -static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) { +static struct ggml_backend_metal_context * ggml_metal_init(void) { GGML_METAL_LOG_INFO("%s: allocating\n", __func__); #if TARGET_OS_OSX && !GGML_METAL_NDEBUG @@ -322,8 +348,9 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) { // Configure context struct ggml_backend_metal_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_context)); ctx->device = device; - ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS); ctx->queue = [ctx->device newCommandQueue]; + ctx->edesc = MTLComputePassDescriptor.computePassDescriptor; + ctx->edesc.dispatchType = MTLDispatchTypeSerial; ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT); id metal_library; @@ -455,7 +482,15 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_LOG_INFO("%s: simdgroup matrix mul. support = %s\n", __func__, ctx->support_simdgroup_mm ? "true" : "false"); GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false"); - ctx->should_capture_next_compute = false; + ctx->capture_next_compute = false; + ctx->capture_started = false; + ctx->capture_scope = nil; + + ctx->gf = nil; + ctx->encode_async = nil; + for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) { + ctx->command_buffers[i] = nil; + } #if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15) if (@available(macOS 10.12, iOS 16.0, *)) { @@ -686,6 +721,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) { } [metal_library release]; + return ctx; } @@ -874,874 +910,820 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx } } -static enum ggml_status ggml_metal_graph_compute( - struct ggml_backend_metal_context * ctx, - struct ggml_cgraph * gf) { +static void ggml_metal_encode_node( + struct ggml_backend_metal_context * ctx, + int idx, + id encoder) { + struct ggml_cgraph * gf = ctx->gf; - @autoreleasepool { - MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor; - edesc.dispatchType = MTLDispatchTypeSerial; + struct ggml_tensor * node = ggml_graph_node(gf, idx); - // create multiple command buffers and enqueue them - // then, we encode the graph into the command buffers in parallel + //GGML_METAL_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, idx, ggml_op_name(node->op)); - const int n_nodes = gf->n_nodes; - const int n_cb = ctx->n_cb; - const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb; + struct ggml_tensor * src0 = node->src[0]; + struct ggml_tensor * src1 = node->src[1]; + struct ggml_tensor * src2 = node->src[2]; + struct ggml_tensor * dst = node; - const bool should_capture = ctx->should_capture_next_compute; - if (should_capture) { - ctx->should_capture_next_compute = false; - - MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new]; - descriptor.captureObject = ctx->queue; - - NSError * error = nil; - if (![[MTLCaptureManager sharedCaptureManager] startCaptureWithDescriptor:descriptor error:&error]) { - GGML_METAL_LOG_ERROR("%s: error: unable to start capture '%s'\n", __func__, [[error localizedDescription] UTF8String]); - GGML_ABORT("capture failed"); - } + if (ggml_is_empty(dst)) { + return; } - id command_buffer_builder[n_cb]; - for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) { - id command_buffer = [ctx->queue commandBufferWithUnretainedReferences]; - command_buffer_builder[cb_idx] = command_buffer; - - // always enqueue the first two command buffers - // enqueue all of the command buffers if we don't need to abort - if (cb_idx < 2 || ctx->abort_callback == NULL) { - [command_buffer enqueue]; - } + switch (dst->op) { + case GGML_OP_NONE: + case GGML_OP_RESHAPE: + case GGML_OP_VIEW: + case GGML_OP_TRANSPOSE: + case GGML_OP_PERMUTE: + { + // noop -> next node + } return; + default: + { + } break; } - const id *command_buffers = command_buffer_builder; + if (!ggml_metal_supports_op(ctx, dst)) { + GGML_METAL_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst)); + GGML_ABORT("unsupported op"); + } - dispatch_apply(n_cb, ctx->d_queue, ^(size_t iter) { - const int cb_idx = iter; + const int64_t ne00 = src0 ? src0->ne[0] : 0; + const int64_t ne01 = src0 ? src0->ne[1] : 0; + const int64_t ne02 = src0 ? src0->ne[2] : 0; + const int64_t ne03 = src0 ? src0->ne[3] : 0; - size_t offs_src0 = 0; - size_t offs_src1 = 0; - size_t offs_src2 = 0; - size_t offs_dst = 0; + const uint64_t nb00 = src0 ? src0->nb[0] : 0; + const uint64_t nb01 = src0 ? src0->nb[1] : 0; + const uint64_t nb02 = src0 ? src0->nb[2] : 0; + const uint64_t nb03 = src0 ? src0->nb[3] : 0; - id command_buffer = command_buffers[cb_idx]; - id encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc]; + const int64_t ne10 = src1 ? src1->ne[0] : 0; + const int64_t ne11 = src1 ? src1->ne[1] : 0; + const int64_t ne12 = src1 ? src1->ne[2] : 0; + const int64_t ne13 = src1 ? src1->ne[3] : 0; - const int node_start = (cb_idx + 0) * n_nodes_per_cb; - const int node_end = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes); + const uint64_t nb10 = src1 ? src1->nb[0] : 0; + const uint64_t nb11 = src1 ? src1->nb[1] : 0; + const uint64_t nb12 = src1 ? src1->nb[2] : 0; + const uint64_t nb13 = src1 ? src1->nb[3] : 0; - for (int i = node_start; i < node_end; ++i) { - if (i == -1) { - [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers]; - continue; - } + const int64_t ne20 = src2 ? src2->ne[0] : 0; + const int64_t ne21 = src2 ? src2->ne[1] : 0; + const int64_t ne22 = src2 ? src2->ne[2] : 0; GGML_UNUSED(ne22); + const int64_t ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23); - //GGML_METAL_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op)); + const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20); + const uint64_t nb21 = src2 ? src2->nb[1] : 0; + const uint64_t nb22 = src2 ? src2->nb[2] : 0; + const uint64_t nb23 = src2 ? src2->nb[3] : 0; - struct ggml_tensor * src0 = gf->nodes[i]->src[0]; - struct ggml_tensor * src1 = gf->nodes[i]->src[1]; - struct ggml_tensor * src2 = gf->nodes[i]->src[2]; - struct ggml_tensor * dst = gf->nodes[i]; + const int64_t ne0 = dst ? dst->ne[0] : 0; + const int64_t ne1 = dst ? dst->ne[1] : 0; + const int64_t ne2 = dst ? dst->ne[2] : 0; + const int64_t ne3 = dst ? dst->ne[3] : 0; - if (ggml_is_empty(dst)) { - continue; - } + const uint64_t nb0 = dst ? dst->nb[0] : 0; + const uint64_t nb1 = dst ? dst->nb[1] : 0; + const uint64_t nb2 = dst ? dst->nb[2] : 0; + const uint64_t nb3 = dst ? dst->nb[3] : 0; - switch (dst->op) { - case GGML_OP_NONE: - case GGML_OP_RESHAPE: - case GGML_OP_VIEW: - case GGML_OP_TRANSPOSE: - case GGML_OP_PERMUTE: - { - // noop -> next node - } continue; + const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT; + const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT; + const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT; + + size_t offs_src0 = 0; + size_t offs_src1 = 0; + size_t offs_src2 = 0; + size_t offs_dst = 0; + + id id_src0 = src0 ? ggml_metal_get_buffer(src0, &offs_src0) : nil; + id id_src1 = src1 ? ggml_metal_get_buffer(src1, &offs_src1) : nil; + id id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil; + id id_dst = dst ? ggml_metal_get_buffer(dst, &offs_dst) : nil; + + //GGML_METAL_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op)); + //if (src0) { + // GGML_METAL_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02, + // ggml_is_contiguous(src0), src0->name); + //} + //if (src1) { + // GGML_METAL_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12, + // ggml_is_contiguous(src1), src1->name); + //} + //if (dst) { + // GGML_METAL_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2, + // dst->name); + //} + + switch (dst->op) { + case GGML_OP_CONCAT: + { + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline; + + const int32_t dim = ((const int32_t *) dst->op_params)[0]; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; + [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9]; + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10]; + [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11]; + [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12]; + [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13]; + [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17]; + [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20]; + [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21]; + [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22]; + [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24]; + [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25]; + [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26]; + [encoder setBytes:&dim length:sizeof(dim) atIndex:27]; + + const int nth = MIN(1024, ne0); + + [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_ADD: + case GGML_OP_SUB: + case GGML_OP_MUL: + case GGML_OP_DIV: + { + GGML_ASSERT(src0t == GGML_TYPE_F32); + GGML_ASSERT(src1t == GGML_TYPE_F32); + + const size_t offs = 0; + + bool bcast_row = false; + + int64_t nb = ne00; // used by the "row" kernels + + id pipeline = nil; + + if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) { + GGML_ASSERT(ggml_is_contiguous(src0)); + + // src1 is a row + GGML_ASSERT(ne11 == 1); + + nb = ne00 / 4; + switch (dst->op) { + case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break; + case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB_ROW].pipeline; break; + case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline; break; + case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline; break; + default: GGML_ABORT("fatal error"); + } + + bcast_row = true; + } else { + switch (dst->op) { + case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; break; + case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB].pipeline; break; + case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break; + case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break; + default: GGML_ABORT("fatal error"); + } + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; + [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9]; + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10]; + [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11]; + [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12]; + [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13]; + [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17]; + [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20]; + [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21]; + [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22]; + [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24]; + [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25]; + [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26]; + [encoder setBytes:&offs length:sizeof(offs) atIndex:27]; + [encoder setBytes:&nb length:sizeof(nb) atIndex:28]; + + if (bcast_row) { + const int64_t n = ggml_nelements(dst)/4; + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } else { + const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0); + + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } + } break; + case GGML_OP_REPEAT: + { + id pipeline; + + switch (src0t) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F16].pipeline; break; + case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I32].pipeline; break; + case GGML_TYPE_I16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I16].pipeline; break; + default: GGML_ABORT("fatal error"); + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; + [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11]; + [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12]; + [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13]; + [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15]; + [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16]; + [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17]; + + const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0); + + [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_ACC: + { + GGML_ASSERT(src0t == GGML_TYPE_F32); + GGML_ASSERT(src1t == GGML_TYPE_F32); + GGML_ASSERT(dstt == GGML_TYPE_F32); + + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src1)); + + const size_t pnb1 = ((const int32_t *) dst->op_params)[0]; + const size_t pnb2 = ((const int32_t *) dst->op_params)[1]; + const size_t pnb3 = ((const int32_t *) dst->op_params)[2]; + const size_t offs = ((const int32_t *) dst->op_params)[3]; + + const bool inplace = (bool) ((const int32_t *) dst->op_params)[4]; + + if (!inplace) { + // run a separete kernel to cpy src->dst + // not sure how to avoid this + // TODO: make a simpler cpy_bytes kernel + + const id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; + [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5]; + [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8]; + [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9]; + [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10]; + [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11]; + [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12]; + [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13]; + [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14]; + [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15]; + [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; + [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; + + const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00); + + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } + + const id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; + [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7]; + [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:8]; + [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:9]; + [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:10]; + [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11]; + [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12]; + [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13]; + [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17]; + [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20]; + [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21]; + [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22]; + [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23]; + [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:24]; + [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:25]; + [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26]; + [encoder setBytes:&offs length:sizeof(offs) atIndex:27]; + + const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00); + + [encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_SCALE: + { + GGML_ASSERT(ggml_is_contiguous(src0)); + + float scale; + memcpy(&scale, dst->op_params, sizeof(scale)); + + int64_t n = ggml_nelements(dst); + + id pipeline = nil; + + if (n % 4 == 0) { + n /= 4; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE_4].pipeline; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE].pipeline; + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&scale length:sizeof(scale) atIndex:2]; + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_CLAMP: + { + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CLAMP].pipeline; + + float min; + float max; + memcpy(&min, ((const int32_t *) dst->op_params) + 0, sizeof(float)); + memcpy(&max, ((const int32_t *) dst->op_params) + 1, sizeof(float)); + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&min length:sizeof(min) atIndex:2]; + [encoder setBytes:&max length:sizeof(max) atIndex:3]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_UNARY: + switch (ggml_get_unary_op(node)) { + // we are not taking into account the strides, so for now require contiguous tensors + GGML_ASSERT(ggml_is_contiguous(src0)); + + case GGML_UNARY_OP_TANH: + { + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TANH].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_UNARY_OP_RELU: + { + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RELU].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_UNARY_OP_SIGMOID: + { + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SIGMOID].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_UNARY_OP_GELU: + { + int64_t n = ggml_nelements(dst); + + id pipeline = nil; + + if (n % 4 == 0) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_4].pipeline; + n /= 4; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU].pipeline; + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_UNARY_OP_GELU_QUICK: + { + int64_t n = ggml_nelements(dst); + + id pipeline = nil; + + if (n % 4 == 0) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK_4].pipeline; + n /= 4; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline; + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_UNARY_OP_SILU: + { + int64_t n = ggml_nelements(dst); + + id pipeline = nil; + + if (n % 4 == 0) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU_4].pipeline; + n /= 4; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU].pipeline; + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; default: - { - } break; - } - - if (!ggml_metal_supports_op(ctx, dst)) { - GGML_METAL_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst)); - GGML_ABORT("unsupported op"); - } - - if (should_capture) { - [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(dst) encoding:NSUTF8StringEncoding]]; - } - - const int64_t ne00 = src0 ? src0->ne[0] : 0; - const int64_t ne01 = src0 ? src0->ne[1] : 0; - const int64_t ne02 = src0 ? src0->ne[2] : 0; - const int64_t ne03 = src0 ? src0->ne[3] : 0; - - const uint64_t nb00 = src0 ? src0->nb[0] : 0; - const uint64_t nb01 = src0 ? src0->nb[1] : 0; - const uint64_t nb02 = src0 ? src0->nb[2] : 0; - const uint64_t nb03 = src0 ? src0->nb[3] : 0; - - const int64_t ne10 = src1 ? src1->ne[0] : 0; - const int64_t ne11 = src1 ? src1->ne[1] : 0; - const int64_t ne12 = src1 ? src1->ne[2] : 0; - const int64_t ne13 = src1 ? src1->ne[3] : 0; - - const uint64_t nb10 = src1 ? src1->nb[0] : 0; - const uint64_t nb11 = src1 ? src1->nb[1] : 0; - const uint64_t nb12 = src1 ? src1->nb[2] : 0; - const uint64_t nb13 = src1 ? src1->nb[3] : 0; - - const int64_t ne20 = src2 ? src2->ne[0] : 0; - const int64_t ne21 = src2 ? src2->ne[1] : 0; - const int64_t ne22 = src2 ? src2->ne[2] : 0; GGML_UNUSED(ne22); - const int64_t ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23); - - const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20); - const uint64_t nb21 = src2 ? src2->nb[1] : 0; - const uint64_t nb22 = src2 ? src2->nb[2] : 0; - const uint64_t nb23 = src2 ? src2->nb[3] : 0; - - const int64_t ne0 = dst ? dst->ne[0] : 0; - const int64_t ne1 = dst ? dst->ne[1] : 0; - const int64_t ne2 = dst ? dst->ne[2] : 0; - const int64_t ne3 = dst ? dst->ne[3] : 0; - - const uint64_t nb0 = dst ? dst->nb[0] : 0; - const uint64_t nb1 = dst ? dst->nb[1] : 0; - const uint64_t nb2 = dst ? dst->nb[2] : 0; - const uint64_t nb3 = dst ? dst->nb[3] : 0; - - const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT; - const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT; - const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT; - - id id_src0 = src0 ? ggml_metal_get_buffer(src0, &offs_src0) : nil; - id id_src1 = src1 ? ggml_metal_get_buffer(src1, &offs_src1) : nil; - id id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil; - id id_dst = dst ? ggml_metal_get_buffer(dst, &offs_dst) : nil; - - //GGML_METAL_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op)); - //if (src0) { - // GGML_METAL_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02, - // ggml_is_contiguous(src0), src0->name); - //} - //if (src1) { - // GGML_METAL_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12, - // ggml_is_contiguous(src1), src1->name); - //} - //if (dst) { - // GGML_METAL_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2, - // dst->name); - //} - - switch (dst->op) { - case GGML_OP_CONCAT: - { - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline; - - const int32_t dim = ((int32_t *) dst->op_params)[0]; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; - [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13]; - [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17]; - [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20]; - [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21]; - [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24]; - [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25]; - [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26]; - [encoder setBytes:&dim length:sizeof(dim) atIndex:27]; - - const int nth = MIN(1024, ne0); - - [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_ADD: - case GGML_OP_SUB: - case GGML_OP_MUL: - case GGML_OP_DIV: - { - GGML_ASSERT(src0t == GGML_TYPE_F32); - GGML_ASSERT(src1t == GGML_TYPE_F32); - - const size_t offs = 0; - - bool bcast_row = false; - - int64_t nb = ne00; // used by the "row" kernels - - id pipeline = nil; - - if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) { - GGML_ASSERT(ggml_is_contiguous(src0)); - - // src1 is a row - GGML_ASSERT(ne11 == 1); - - nb = ne00 / 4; - switch (dst->op) { - case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break; - case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB_ROW].pipeline; break; - case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline; break; - case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline; break; - default: GGML_ABORT("fatal error"); - } - - bcast_row = true; - } else { - switch (dst->op) { - case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; break; - case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB].pipeline; break; - case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break; - case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break; - default: GGML_ABORT("fatal error"); - } - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; - [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13]; - [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17]; - [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20]; - [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21]; - [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24]; - [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25]; - [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26]; - [encoder setBytes:&offs length:sizeof(offs) atIndex:27]; - [encoder setBytes:&nb length:sizeof(nb) atIndex:28]; - - if (bcast_row) { - const int64_t n = ggml_nelements(dst)/4; - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } else { - const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0); - - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } - } break; - case GGML_OP_REPEAT: - { - id pipeline; - - switch (src0t) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F32].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F16].pipeline; break; - case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I32].pipeline; break; - case GGML_TYPE_I16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I16].pipeline; break; - default: GGML_ABORT("fatal error"); - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; - [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11]; - [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12]; - [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15]; - [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16]; - [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17]; - - const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0); - - [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_ACC: - { - GGML_ASSERT(src0t == GGML_TYPE_F32); - GGML_ASSERT(src1t == GGML_TYPE_F32); - GGML_ASSERT(dstt == GGML_TYPE_F32); - - GGML_ASSERT(ggml_is_contiguous(src0)); - GGML_ASSERT(ggml_is_contiguous(src1)); - - const size_t pnb1 = ((int32_t *) dst->op_params)[0]; - const size_t pnb2 = ((int32_t *) dst->op_params)[1]; - const size_t pnb3 = ((int32_t *) dst->op_params)[2]; - const size_t offs = ((int32_t *) dst->op_params)[3]; - - const bool inplace = (bool) ((int32_t *) dst->op_params)[4]; - - if (!inplace) { - // run a separete kernel to cpy src->dst - // not sure how to avoid this - // TODO: make a simpler cpy_bytes kernel - - const id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; - [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12]; - [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14]; - [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15]; - [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; - [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; - - const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00); - - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } - - const id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; - [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7]; - [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:8]; - [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:9]; - [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:10]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13]; - [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17]; - [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20]; - [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21]; - [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23]; - [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:24]; - [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:25]; - [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26]; - [encoder setBytes:&offs length:sizeof(offs) atIndex:27]; - - const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00); - - [encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_SCALE: - { - GGML_ASSERT(ggml_is_contiguous(src0)); - - float scale; - memcpy(&scale, dst->op_params, sizeof(scale)); - - int64_t n = ggml_nelements(dst); - - id pipeline = nil; - - if (n % 4 == 0) { - n /= 4; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE_4].pipeline; - } else { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE].pipeline; - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&scale length:sizeof(scale) atIndex:2]; - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_OP_CLAMP: - { - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CLAMP].pipeline; - - float min; - float max; - memcpy(&min, ((int32_t *) dst->op_params) + 0, sizeof(float)); - memcpy(&max, ((int32_t *) dst->op_params) + 1, sizeof(float)); - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&min length:sizeof(min) atIndex:2]; - [encoder setBytes:&max length:sizeof(max) atIndex:3]; - - const int64_t n = ggml_nelements(dst); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_OP_UNARY: - switch (ggml_get_unary_op(gf->nodes[i])) { - // we are not taking into account the strides, so for now require contiguous tensors - GGML_ASSERT(ggml_is_contiguous(src0)); - - case GGML_UNARY_OP_TANH: - { - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TANH].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - const int64_t n = ggml_nelements(dst); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_UNARY_OP_RELU: - { - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RELU].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - const int64_t n = ggml_nelements(dst); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_UNARY_OP_SIGMOID: - { - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SIGMOID].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - const int64_t n = ggml_nelements(dst); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_UNARY_OP_GELU: - { - int64_t n = ggml_nelements(dst); - - id pipeline = nil; - - if (n % 4 == 0) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_4].pipeline; - n /= 4; - } else { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU].pipeline; - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_UNARY_OP_GELU_QUICK: - { - int64_t n = ggml_nelements(dst); - - id pipeline = nil; - - if (n % 4 == 0) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK_4].pipeline; - n /= 4; - } else { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline; - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_UNARY_OP_SILU: - { - int64_t n = ggml_nelements(dst); - - id pipeline = nil; - - if (n % 4 == 0) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU_4].pipeline; - n /= 4; - } else { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU].pipeline; - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - default: - { - GGML_METAL_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op)); - GGML_ABORT("fatal error"); - } - } break; - case GGML_OP_SQR: - { - GGML_ASSERT(ggml_is_contiguous(src0)); - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SQR].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - const int64_t n = ggml_nelements(dst); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_OP_SQRT: - { - GGML_ASSERT(ggml_is_contiguous(src0)); - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SQRT].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - const int64_t n = ggml_nelements(dst); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_OP_SIN: - { - GGML_ASSERT(ggml_is_contiguous(src0)); - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SIN].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - const int64_t n = ggml_nelements(dst); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_OP_COS: - { - GGML_ASSERT(ggml_is_contiguous(src0)); - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_COS].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - const int64_t n = ggml_nelements(dst); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_OP_SUM_ROWS: - { - GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type)); - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; - [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12]; - [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16]; - [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:19]; - [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:20]; - [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:21]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:22]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:23]; - [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:24]; - [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:25]; - - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_OP_SOFT_MAX: - { - GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); - - int nth = 32; // SIMD width - - id pipeline = nil; - - const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16); - - if (ne00%4 == 0) { - while (nth < ne00/4 && nth*ne01*ne02*ne03 < 256) { - nth *= 2; - } - if (use_f16) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4].pipeline; - } else { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4].pipeline; - } - } else { - while (nth < ne00 && nth*ne01*ne02*ne03 < 256) { - nth *= 2; - } - if (use_f16) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16].pipeline; - } else { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32].pipeline; - } - } - - float scale; - float max_bias; - - memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale)); - memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias)); - - const int64_t nrows_x = ggml_nrows(src0); - const int64_t nrows_y = src0->ne[1]; - - const uint32_t n_head = nrows_x/nrows_y; - const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); - - const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); - const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - if (id_src1) { - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - } else { - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; - } - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; - [encoder setBytes:&scale length:sizeof(scale) atIndex:6]; - [encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:7]; - [encoder setBytes:&m0 length:sizeof(m0) atIndex:8]; - [encoder setBytes:&m1 length:sizeof(m1) atIndex:9]; - [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:10]; - [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; - - [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_DIAG_MASK_INF: - { - const int n_past = ((int32_t *)(dst->op_params))[0]; - - id pipeline = nil; - - if (ne00%8 == 0) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8].pipeline; - } else { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline; - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; - [encoder setBytes:&n_past length:sizeof(int) atIndex:4]; - - if (ne00%8 == 0) { - [encoder dispatchThreadgroups:MTLSizeMake(ne00*ne01*ne02/8, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } - else { - [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } - } break; - case GGML_OP_SSM_CONV: - { - GGML_ASSERT(src0t == GGML_TYPE_F32); - GGML_ASSERT(src1t == GGML_TYPE_F32); - - GGML_ASSERT(ggml_is_contiguous(src0)); - GGML_ASSERT(ggml_is_contiguous(src1)); - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_CONV_F32].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:11]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:12]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14]; - [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:15]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:16]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:17]; - [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:18]; - - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne1, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_OP_SSM_SCAN: - { - struct ggml_tensor * src3 = gf->nodes[i]->src[3]; - struct ggml_tensor * src4 = gf->nodes[i]->src[4]; - struct ggml_tensor * src5 = gf->nodes[i]->src[5]; - - GGML_ASSERT(src3); - GGML_ASSERT(src4); - GGML_ASSERT(src5); - - size_t offs_src3 = 0; - size_t offs_src4 = 0; - size_t offs_src5 = 0; - - id id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil; - id id_src4 = src4 ? ggml_metal_get_buffer(src4, &offs_src4) : nil; - id id_src5 = src5 ? ggml_metal_get_buffer(src5, &offs_src5) : nil; - - const int64_t ne30 = src3->ne[0]; GGML_UNUSED(ne30); - const int64_t ne31 = src3->ne[1]; GGML_UNUSED(ne31); - - const uint64_t nb30 = src3->nb[0]; - const uint64_t nb31 = src3->nb[1]; - - const int64_t ne40 = src4->ne[0]; GGML_UNUSED(ne40); - const int64_t ne41 = src4->ne[1]; GGML_UNUSED(ne41); - const int64_t ne42 = src4->ne[2]; GGML_UNUSED(ne42); - - const uint64_t nb40 = src4->nb[0]; - const uint64_t nb41 = src4->nb[1]; - const uint64_t nb42 = src4->nb[2]; - - const int64_t ne50 = src5->ne[0]; GGML_UNUSED(ne50); - const int64_t ne51 = src5->ne[1]; GGML_UNUSED(ne51); - const int64_t ne52 = src5->ne[2]; GGML_UNUSED(ne52); - - const uint64_t nb50 = src5->nb[0]; - const uint64_t nb51 = src5->nb[1]; - const uint64_t nb52 = src5->nb[2]; - - const int64_t d_state = ne00; - const int64_t d_inner = ne01; - const int64_t n_seq_tokens = ne11; - const int64_t n_seqs = ne02; - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; - [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3]; - [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4]; - [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:6]; - - [encoder setBytes:&d_state length:sizeof(d_state) atIndex:7]; - [encoder setBytes:&d_inner length:sizeof(d_inner) atIndex:8]; - [encoder setBytes:&n_seq_tokens length:sizeof(n_seq_tokens) atIndex:9]; - [encoder setBytes:&n_seqs length:sizeof(n_seqs) atIndex:10]; - - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:11]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:12]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:13]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16]; - [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17]; - [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:18]; - [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:19]; - [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:20]; - [encoder setBytes:&nb30 length:sizeof(nb30) atIndex:21]; - [encoder setBytes:&nb31 length:sizeof(nb31) atIndex:22]; - [encoder setBytes:&nb40 length:sizeof(nb40) atIndex:23]; - [encoder setBytes:&nb41 length:sizeof(nb41) atIndex:24]; - [encoder setBytes:&nb42 length:sizeof(nb42) atIndex:25]; - [encoder setBytes:&nb50 length:sizeof(nb50) atIndex:26]; - [encoder setBytes:&nb51 length:sizeof(nb51) atIndex:27]; - [encoder setBytes:&nb52 length:sizeof(nb52) atIndex:28]; - - [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_OP_MUL_MAT: - { - GGML_ASSERT(ne00 == ne10); - - GGML_ASSERT(ne12 % ne02 == 0); - GGML_ASSERT(ne13 % ne03 == 0); - - const uint r2 = ne12/ne02; - const uint r3 = ne13/ne03; - - // find the break-even point where the matrix-matrix kernel becomes more efficient compared - // to the matrix-vector kernel - int ne11_mm_min = 1; + { + GGML_METAL_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op)); + GGML_ABORT("fatal error"); + } + } break; + case GGML_OP_SQR: + { + GGML_ASSERT(ggml_is_contiguous(src0)); + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SQR].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_SQRT: + { + GGML_ASSERT(ggml_is_contiguous(src0)); + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SQRT].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_SIN: + { + GGML_ASSERT(ggml_is_contiguous(src0)); + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SIN].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_COS: + { + GGML_ASSERT(ggml_is_contiguous(src0)); + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_COS].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_SUM_ROWS: + { + GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type)); + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; + [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; + [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10]; + [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11]; + [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12]; + [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16]; + [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:19]; + [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:20]; + [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:21]; + [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:22]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:23]; + [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:24]; + [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:25]; + + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_SOFT_MAX: + { + GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); + + int nth = 32; // SIMD width + + id pipeline = nil; + + const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16); + + if (ne00%4 == 0) { + while (nth < ne00/4 && nth*ne01*ne02*ne03 < 256) { + nth *= 2; + } + if (use_f16) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4].pipeline; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4].pipeline; + } + } else { + while (nth < ne00 && nth*ne01*ne02*ne03 < 256) { + nth *= 2; + } + if (use_f16) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16].pipeline; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32].pipeline; + } + } + + float scale; + float max_bias; + + memcpy(&scale, ((const int32_t *) dst->op_params) + 0, sizeof(scale)); + memcpy(&max_bias, ((const int32_t *) dst->op_params) + 1, sizeof(max_bias)); + + const int64_t nrows_x = ggml_nrows(src0); + const int64_t nrows_y = src0->ne[1]; + + const uint32_t n_head = nrows_x/nrows_y; + const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); + + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + if (id_src1) { + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + } else { + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + } + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; + [encoder setBytes:&scale length:sizeof(scale) atIndex:6]; + [encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:7]; + [encoder setBytes:&m0 length:sizeof(m0) atIndex:8]; + [encoder setBytes:&m1 length:sizeof(m1) atIndex:9]; + [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:10]; + [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; + + [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_DIAG_MASK_INF: + { + const int n_past = ((const int32_t *)(dst->op_params))[0]; + + id pipeline = nil; + + if (ne00%8 == 0) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8].pipeline; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline; + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; + [encoder setBytes:&n_past length:sizeof(int) atIndex:4]; + + if (ne00%8 == 0) { + [encoder dispatchThreadgroups:MTLSizeMake(ne00*ne01*ne02/8, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } + else { + [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } + } break; + case GGML_OP_SSM_CONV: + { + GGML_ASSERT(src0t == GGML_TYPE_F32); + GGML_ASSERT(src1t == GGML_TYPE_F32); + + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src1)); + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_CONV_F32].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; + [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9]; + [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:11]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:12]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14]; + [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:15]; + [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:16]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:17]; + [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:18]; + + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne1, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_SSM_SCAN: + { + struct ggml_tensor * src3 = node->src[3]; + struct ggml_tensor * src4 = node->src[4]; + struct ggml_tensor * src5 = node->src[5]; + + GGML_ASSERT(src3); + GGML_ASSERT(src4); + GGML_ASSERT(src5); + + size_t offs_src3 = 0; + size_t offs_src4 = 0; + size_t offs_src5 = 0; + + id id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil; + id id_src4 = src4 ? ggml_metal_get_buffer(src4, &offs_src4) : nil; + id id_src5 = src5 ? ggml_metal_get_buffer(src5, &offs_src5) : nil; + + const int64_t ne30 = src3->ne[0]; GGML_UNUSED(ne30); + const int64_t ne31 = src3->ne[1]; GGML_UNUSED(ne31); + + const uint64_t nb30 = src3->nb[0]; + const uint64_t nb31 = src3->nb[1]; + + const int64_t ne40 = src4->ne[0]; GGML_UNUSED(ne40); + const int64_t ne41 = src4->ne[1]; GGML_UNUSED(ne41); + const int64_t ne42 = src4->ne[2]; GGML_UNUSED(ne42); + + const uint64_t nb40 = src4->nb[0]; + const uint64_t nb41 = src4->nb[1]; + const uint64_t nb42 = src4->nb[2]; + + const int64_t ne50 = src5->ne[0]; GGML_UNUSED(ne50); + const int64_t ne51 = src5->ne[1]; GGML_UNUSED(ne51); + const int64_t ne52 = src5->ne[2]; GGML_UNUSED(ne52); + + const uint64_t nb50 = src5->nb[0]; + const uint64_t nb51 = src5->nb[1]; + const uint64_t nb52 = src5->nb[2]; + + const int64_t d_state = ne00; + const int64_t d_inner = ne01; + const int64_t n_seq_tokens = ne11; + const int64_t n_seqs = ne02; + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; + [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3]; + [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4]; + [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:6]; + + [encoder setBytes:&d_state length:sizeof(d_state) atIndex:7]; + [encoder setBytes:&d_inner length:sizeof(d_inner) atIndex:8]; + [encoder setBytes:&n_seq_tokens length:sizeof(n_seq_tokens) atIndex:9]; + [encoder setBytes:&n_seqs length:sizeof(n_seqs) atIndex:10]; + + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:11]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:12]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:13]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16]; + [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17]; + [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:18]; + [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:19]; + [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:20]; + [encoder setBytes:&nb30 length:sizeof(nb30) atIndex:21]; + [encoder setBytes:&nb31 length:sizeof(nb31) atIndex:22]; + [encoder setBytes:&nb40 length:sizeof(nb40) atIndex:23]; + [encoder setBytes:&nb41 length:sizeof(nb41) atIndex:24]; + [encoder setBytes:&nb42 length:sizeof(nb42) atIndex:25]; + [encoder setBytes:&nb50 length:sizeof(nb50) atIndex:26]; + [encoder setBytes:&nb51 length:sizeof(nb51) atIndex:27]; + [encoder setBytes:&nb52 length:sizeof(nb52) atIndex:28]; + + [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_MUL_MAT: + { + GGML_ASSERT(ne00 == ne10); + + GGML_ASSERT(ne12 % ne02 == 0); + GGML_ASSERT(ne13 % ne03 == 0); + + const uint r2 = ne12/ne02; + const uint r3 = ne13/ne03; + + // find the break-even point where the matrix-matrix kernel becomes more efficient compared + // to the matrix-vector kernel + int ne11_mm_min = 1; #if 0 - // the numbers below are measured on M2 Ultra for 7B and 13B models - // these numbers do not translate to other devices or model sizes - // TODO: need to find a better approach + // the numbers below are measured on M2 Ultra for 7B and 13B models + // these numbers do not translate to other devices or model sizes + // TODO: need to find a better approach if ([ctx->device.name isEqualToString:@"Apple M2 Ultra"]) { switch (src0t) { case GGML_TYPE_F16: ne11_mm_min = 2; break; @@ -1763,11 +1745,11 @@ static enum ggml_status ggml_metal_graph_compute( // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && - !ggml_is_transposed(src0) && - !ggml_is_transposed(src1) && - src1t == GGML_TYPE_F32 && - ne00 % 32 == 0 && ne00 >= 64 && - (ne11 > ne11_mm_min || (ggml_is_quantized(src0t) && ne12 > 1))) { + !ggml_is_transposed(src0) && + !ggml_is_transposed(src1) && + src1t == GGML_TYPE_F32 && + ne00 % 32 == 0 && ne00 >= 64 && + (ne11 > ne11_mm_min || (ggml_is_quantized(src0t) && ne12 > 1))) { //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); // some Metal matrix data types require aligned pointers @@ -2001,8 +1983,8 @@ static enum ggml_status ggml_metal_graph_compute( [encoder setBytes:&r3 length:sizeof(r3) atIndex:18]; if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 || - src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K || - src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) { + src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K || + src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) { [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) { @@ -2036,1041 +2018,1158 @@ static enum ggml_status ggml_metal_graph_compute( [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } } - } break; - case GGML_OP_MUL_MAT_ID: - { - const int n_as = src0->ne[2]; - - // src2 = ids - const enum ggml_type src2t = src2->type; GGML_UNUSED(src2t); - - GGML_ASSERT(src2t == GGML_TYPE_I32); - - GGML_ASSERT(!ggml_is_transposed(src0)); - GGML_ASSERT(!ggml_is_transposed(src1)); - - GGML_ASSERT(src1t == GGML_TYPE_F32); - - // find the break-even point where the matrix-matrix kernel becomes more efficient compared - // to the matrix-vector kernel - // ne20 = n_used_experts - // ne21 = n_rows - const int dst_rows = ne20*ne21; - const int dst_rows_min = n_as; - const int dst_rows_max = (ctx->device.maxThreadgroupMemoryLength - 32 - 8192)/4; - - // max size of the rowids array in the kernel shared buffer - GGML_ASSERT(dst_rows <= dst_rows_max); - - // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs - // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel - // !!! - // TODO: for now, always use mat-vec kernels until we figure out how to improve the - // indirect matrix multiplication - // !!! - if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && - ne00 % 32 == 0 && ne00 >= 64 && - dst_rows > dst_rows_min) { - - // some Metal matrix data types require aligned pointers - // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5) - switch (src0->type) { - case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break; - case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break; - default: break; - } - - id pipeline = nil; - - switch (src0->type) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32 ].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32 ].pipeline; break; - case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32 ].pipeline; break; - case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32 ].pipeline; break; - case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32 ].pipeline; break; - case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32 ].pipeline; break; - case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32 ].pipeline; break; - case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32 ].pipeline; break; - case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32 ].pipeline; break; - case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32 ].pipeline; break; - case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32 ].pipeline; break; - case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32 ].pipeline; break; - case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break; - case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break; - case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32].pipeline; break; - case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32 ].pipeline; break; - case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32 ].pipeline; break; - case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32 ].pipeline; break; - case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32 ].pipeline; break; - case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break; - case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32 ].pipeline; break; - default: GGML_ABORT("MUL_MAT_ID not implemented"); - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3]; - [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4]; - [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5]; - [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:8]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:9]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:10]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12]; - [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:18]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19]; - - [encoder setThreadgroupMemoryLength:GGML_PAD(8192 + dst_rows*4/*sizeof(ushort2)*/, 16) atIndex:0]; - - [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, n_as) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; - } else { - int nth0 = 32; - int nth1 = 1; - int nrows = 1; - //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); - - id pipeline = nil; - - // use custom matrix x vector kernel - switch (src0t) { - case GGML_TYPE_F32: - { - GGML_ASSERT(src1t == GGML_TYPE_F32); - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32].pipeline; - } break; - case GGML_TYPE_F16: - { - GGML_ASSERT(src1t == GGML_TYPE_F32); - nth0 = 32; - nth1 = 1; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32].pipeline; - } break; - case GGML_TYPE_Q4_0: - { - nth0 = 8; - nth1 = 8; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32].pipeline; - } break; - case GGML_TYPE_Q4_1: - { - nth0 = 8; - nth1 = 8; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32].pipeline; - } break; - case GGML_TYPE_Q5_0: - { - nth0 = 8; - nth1 = 8; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32].pipeline; - } break; - case GGML_TYPE_Q5_1: - { - nth0 = 8; - nth1 = 8; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32].pipeline; - } break; - case GGML_TYPE_Q8_0: - { - nth0 = 8; - nth1 = 8; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline; - } break; - case GGML_TYPE_Q2_K: - { - nth0 = 2; - nth1 = 32; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32].pipeline; - } break; - case GGML_TYPE_Q3_K: - { - nth0 = 2; - nth1 = 32; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32].pipeline; - } break; - case GGML_TYPE_Q4_K: - { - nth0 = 4; //1; - nth1 = 8; //32; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32].pipeline; - } break; - case GGML_TYPE_Q5_K: - { - nth0 = 2; - nth1 = 32; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32].pipeline; - } break; - case GGML_TYPE_Q6_K: - { - nth0 = 2; - nth1 = 32; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32].pipeline; - } break; - case GGML_TYPE_IQ2_XXS: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32].pipeline; - } break; - case GGML_TYPE_IQ2_XS: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32].pipeline; - } break; - case GGML_TYPE_IQ3_XXS: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32].pipeline; - } break; - case GGML_TYPE_IQ3_S: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32].pipeline; - } break; - case GGML_TYPE_IQ2_S: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32].pipeline; - } break; - case GGML_TYPE_IQ1_S: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32].pipeline; - } break; - case GGML_TYPE_IQ1_M: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32].pipeline; - } break; - case GGML_TYPE_IQ4_NL: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline; - } break; - case GGML_TYPE_IQ4_XS: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline; - } break; - default: - { - GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t); - GGML_ABORT("not implemented"); - } - }; - - if (ggml_is_quantized(src0t)) { - GGML_ASSERT(ne00 >= nth0*nth1); - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3]; - [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4]; - [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5]; - [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:8]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:9]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:10]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:11]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:12]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:13]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:14]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:15]; - [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:16]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:17]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:18]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:19]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:20]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:21]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:22]; - - const int64_t _ne1 = 1; - const int tgz = dst_rows; - - if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 || - src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K || - src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) { - const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128; - [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) { - const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4; - [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) { - const int mem_size = 32*sizeof(float); - [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_Q4_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_Q3_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_Q5_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_Q6_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } else { - const int64_t ny = (_ne1 + nrows - 1)/nrows; // = _ne1 - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - } - } break; - case GGML_OP_GET_ROWS: - { - id pipeline = nil; - - switch (src0->type) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F32 ].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F16 ].pipeline; break; - case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0 ].pipeline; break; - case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1 ].pipeline; break; - case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0 ].pipeline; break; - case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1 ].pipeline; break; - case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0 ].pipeline; break; - case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K ].pipeline; break; - case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K ].pipeline; break; - case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K ].pipeline; break; - case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K ].pipeline; break; - case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K ].pipeline; break; - case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS].pipeline; break; - case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break; - case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS].pipeline; break; - case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S ].pipeline; break; - case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S ].pipeline; break; - case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S ].pipeline; break; - case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M ].pipeline; break; - case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL ].pipeline; break; - case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS ].pipeline; break; - case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline; break; - default: GGML_ABORT("not implemented"); - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5]; - [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6]; - [encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7]; - [encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8]; - [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:10]; - - [encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)]; - } break; - case GGML_OP_RMS_NORM: - { - GGML_ASSERT(ne00 % 4 == 0); - GGML_ASSERT(ggml_is_contiguous_1(src0)); - - float eps; - memcpy(&eps, dst->op_params, sizeof(float)); - - int nth = 32; // SIMD width - - while (nth < ne00/4 && nth < 1024) { - nth *= 2; - } - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3]; - [encoder setBytes:&eps length:sizeof( float) atIndex:4]; - [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; - - const int64_t nrows = ggml_nrows(src0); - - [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_GROUP_NORM: - { - GGML_ASSERT(ne00 % 4 == 0); - GGML_ASSERT(ggml_is_contiguous(src0)); - - float eps; - memcpy(&eps, dst->op_params + 1, sizeof(float)); - - const int32_t n_groups = ((int32_t *) dst->op_params)[0]; - - int nth = 32; // SIMD width - - //while (nth < ne00/4 && nth < 1024) { - // nth *= 2; - //} - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; - [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:5]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:6]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:7]; - [encoder setBytes:&n_groups length:sizeof( int32_t) atIndex:8]; - [encoder setBytes:&eps length:sizeof( float) atIndex:9]; - [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; - - [encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_NORM: - { - GGML_ASSERT(ggml_is_contiguous_1(src0)); - - float eps; - memcpy(&eps, dst->op_params, sizeof(float)); - - const int nth = MIN(256, ne00); - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_NORM].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3]; - [encoder setBytes:&eps length:sizeof( float) atIndex:4]; - [encoder setThreadgroupMemoryLength:GGML_PAD(nth*sizeof(float), 16) atIndex:0]; - - const int64_t nrows = ggml_nrows(src0); - - [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_ROPE: - { - GGML_ASSERT(ne10 == ne02); - - const int nth = MIN(1024, ne00); - - const int n_past = ((int32_t *) dst->op_params)[0]; - const int n_dims = ((int32_t *) dst->op_params)[1]; - const int mode = ((int32_t *) dst->op_params)[2]; - // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal - const int n_ctx_orig = ((int32_t *) dst->op_params)[4]; - - float freq_base; - float freq_scale; - float ext_factor; - float attn_factor; - float beta_fast; - float beta_slow; - - memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); - memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); - memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float)); - memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float)); - memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); - memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); - - const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; - - id pipeline = nil; - - if (!is_neox) { - switch (src0->type) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break; - default: GGML_ABORT("fatal error"); - }; - } else { - switch (src0->type) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break; - default: GGML_ABORT("fatal error"); - }; - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - if (id_src2 != nil) { - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; - } else { - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:2]; - } - [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:4]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6]; - [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7]; - [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:8]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:10]; - [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:11]; - [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:12]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:14]; - [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:15]; - [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:16]; - [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:17]; - [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:18]; - [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:19]; - [encoder setBytes:&n_past length:sizeof( int) atIndex:20]; - [encoder setBytes:&n_dims length:sizeof( int) atIndex:21]; - [encoder setBytes:&n_ctx_orig length:sizeof( int) atIndex:22]; - [encoder setBytes:&freq_base length:sizeof( float) atIndex:23]; - [encoder setBytes:&freq_scale length:sizeof( float) atIndex:24]; - [encoder setBytes:&ext_factor length:sizeof( float) atIndex:25]; - [encoder setBytes:&attn_factor length:sizeof( float) atIndex:26]; - [encoder setBytes:&beta_fast length:sizeof( float) atIndex:27]; - [encoder setBytes:&beta_slow length:sizeof( float) atIndex:28]; - - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_IM2COL: - { - GGML_ASSERT(src0->type == GGML_TYPE_F16); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32); - - const int32_t s0 = ((const int32_t *)(dst->op_params))[0]; - const int32_t s1 = ((const int32_t *)(dst->op_params))[1]; - const int32_t p0 = ((const int32_t *)(dst->op_params))[2]; - const int32_t p1 = ((const int32_t *)(dst->op_params))[3]; - const int32_t d0 = ((const int32_t *)(dst->op_params))[4]; - const int32_t d1 = ((const int32_t *)(dst->op_params))[5]; - - const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1; - - const int32_t N = src1->ne[is_2D ? 3 : 2]; - const int32_t IC = src1->ne[is_2D ? 2 : 1]; - const int32_t IH = is_2D ? src1->ne[1] : 1; - const int32_t IW = src1->ne[0]; - - const int32_t KH = is_2D ? src0->ne[1] : 1; - const int32_t KW = src0->ne[0]; - - const int32_t OH = is_2D ? dst->ne[2] : 1; - const int32_t OW = dst->ne[1]; - - const int32_t CHW = IC * KH * KW; - - const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4; - const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4; - - id pipeline = nil; - - switch (dst->type) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline; break; - default: GGML_ABORT("fatal error"); - }; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2]; - [encoder setBytes:&ofs1 length:sizeof( int32_t) atIndex:3]; - [encoder setBytes:&IW length:sizeof( int32_t) atIndex:4]; - [encoder setBytes:&IH length:sizeof( int32_t) atIndex:5]; - [encoder setBytes:&CHW length:sizeof( int32_t) atIndex:6]; - [encoder setBytes:&s0 length:sizeof( int32_t) atIndex:7]; - [encoder setBytes:&s1 length:sizeof( int32_t) atIndex:8]; - [encoder setBytes:&p0 length:sizeof( int32_t) atIndex:9]; - [encoder setBytes:&p1 length:sizeof( int32_t) atIndex:10]; - [encoder setBytes:&d0 length:sizeof( int32_t) atIndex:11]; - [encoder setBytes:&d1 length:sizeof( int32_t) atIndex:12]; - - [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)]; - } break; - case GGML_OP_UPSCALE: - { - GGML_ASSERT(src0->type == GGML_TYPE_F32); - - const float sf0 = (float)ne0/src0->ne[0]; - const float sf1 = (float)ne1/src0->ne[1]; - const float sf2 = (float)ne2/src0->ne[2]; - const float sf3 = (float)ne3/src0->ne[3]; - - const id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; - [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11]; - [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12]; - [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15]; - [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16]; - [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17]; - [encoder setBytes:&sf0 length:sizeof(sf0) atIndex:18]; - [encoder setBytes:&sf1 length:sizeof(sf1) atIndex:19]; - [encoder setBytes:&sf2 length:sizeof(sf2) atIndex:20]; - [encoder setBytes:&sf3 length:sizeof(sf3) atIndex:21]; - - const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0); - - [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_PAD: - { - GGML_ASSERT(src0->type == GGML_TYPE_F32); - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; - [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11]; - [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12]; - [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15]; - [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16]; - [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17]; - - const int nth = MIN(1024, ne0); - - [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_ARANGE: - { - GGML_ASSERT(dst->type == GGML_TYPE_F32); - - float start; - float step; - - memcpy(&start, ((int32_t *) dst->op_params) + 0, sizeof(float)); - memcpy(&step, ((int32_t *) dst->op_params) + 2, sizeof(float)); - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARANGE_F32].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:0]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:1]; - [encoder setBytes:&start length:sizeof(start) atIndex:2]; - [encoder setBytes:&step length:sizeof(step) atIndex:3]; - - const int nth = MIN(1024, ne0); - - [encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_TIMESTEP_EMBEDDING: - { - GGML_ASSERT(src0->type == GGML_TYPE_F32); - - const int dim = dst->op_params[0]; - const int max_period = dst->op_params[1]; - - const int half = dim / 2; - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:2]; - [encoder setBytes:&dim length:sizeof(dim) atIndex:3]; - [encoder setBytes:&max_period length:sizeof(max_period) atIndex:4]; - - const int nth = MIN(1024, half); - - [encoder dispatchThreadgroups:MTLSizeMake(ne00, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_ARGSORT: - { - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_I32); - - const int nrows = ggml_nrows(src0); - - enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0]; - - // bitonic sort requires the number of elements to be power of 2 - int64_t ne00_padded = 1; - while (ne00_padded < ne00) { - ne00_padded *= 2; - } - - // Metal kernels require the buffer size to be multiple of 16 bytes - // https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength - const int mem_size = GGML_PAD(ne00_padded*sizeof(int32_t), 16); - - id pipeline = nil; - - switch (order) { - case GGML_SORT_ORDER_ASC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC].pipeline; break; - case GGML_SORT_ORDER_DESC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC].pipeline; break; - default: GGML_ABORT("fatal error"); - }; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&ne00_padded length:sizeof( int64_t) atIndex:3]; - [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; - - [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00_padded, 1, 1)]; - } break; - case GGML_OP_LEAKY_RELU: - { - GGML_ASSERT(src0->type == GGML_TYPE_F32); - - float slope; - memcpy(&slope, dst->op_params, sizeof(float)); - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&slope length:sizeof(slope) atIndex:2]; - - const int64_t n = ggml_nelements(dst); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_OP_FLASH_ATTN_EXT: - { - GGML_ASSERT(ne00 % 4 == 0); - GGML_ASSERT(ne11 % 32 == 0); - - GGML_ASSERT(src0->type == GGML_TYPE_F32); - - GGML_ASSERT(ggml_are_same_shape (src1, src2)); - - struct ggml_tensor * src3 = gf->nodes[i]->src[3]; - - size_t offs_src3 = 0; - - id id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil; - - GGML_ASSERT(!src3 || src3->type == GGML_TYPE_F16); - GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) && - "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big"); - - const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30); - //const int64_t ne31 = src3 ? src3->ne[1] : 0; - const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32); - const int64_t ne33 = src3 ? src3->ne[3] : 0; GGML_UNUSED(ne33); - - const uint64_t nb30 = src3 ? src3->nb[0] : 0; GGML_UNUSED(nb30); - const uint64_t nb31 = src3 ? src3->nb[1] : 0; - const uint64_t nb32 = src3 ? src3->nb[2] : 0; GGML_UNUSED(nb32); - const uint64_t nb33 = src3 ? src3->nb[3] : 0; GGML_UNUSED(nb33); - - const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t); - - float scale; - float max_bias; - float logit_softcap; - memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale)); - memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias)); - memcpy(&logit_softcap, ((int32_t *) dst->op_params) + 2, sizeof(logit_softcap)); - - if (logit_softcap != 0.0f) { - scale /= logit_softcap; - } - - const uint32_t n_head = src0->ne[2]; - const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); - - const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); - const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); - - id pipeline = nil; - - bool use_vec_kernel = false; - - if (ne01 >= 4 || (ne00%128 != 0)) { - switch (ne00) { - case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break; - case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break; - case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break; - case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break; - case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break; - //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break; - default: - { - GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00); - GGML_METAL_LOG_ERROR("add template specialization for this size\n"); - GGML_ABORT("add template specialization for this size"); - } - } - } else { - use_vec_kernel = true; - - switch (ne00) { - case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break; - //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break; - default: - { - GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00); - GGML_METAL_LOG_ERROR("add template specialization for this size\n"); - GGML_ABORT("add template specialization for this size"); - } - } - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; - if (id_src3) { - [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3]; - } else { - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:3]; - } - [encoder setBuffer:id_dst offset:offs_dst atIndex:4]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6]; - [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10]; - [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:11]; - [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:12]; - [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:14]; - [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:15]; - [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:16]; - [encoder setBytes:&nb21 length:sizeof(uint64_t) atIndex:17]; - [encoder setBytes:&nb22 length:sizeof(uint64_t) atIndex:18]; - [encoder setBytes:&nb23 length:sizeof(uint64_t) atIndex:19]; - [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:20]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:21]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:22]; - [encoder setBytes:&scale length:sizeof( float) atIndex:23]; - [encoder setBytes:&max_bias length:sizeof( float) atIndex:24]; - [encoder setBytes:&m0 length:sizeof(m0) atIndex:25]; - [encoder setBytes:&m1 length:sizeof(m1) atIndex:26]; - [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:27]; - [encoder setBytes:&logit_softcap length:sizeof(logit_softcap) atIndex:28]; - - if (!use_vec_kernel) { - // half8x8 kernel - const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! - const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! - - GGML_ASSERT(nqptg <= 32); - GGML_ASSERT(nqptg % 8 == 0); - GGML_ASSERT(ncpsg % 32 == 0); - - int64_t nsgmax = 2; - - while (true) { - const size_t smem = nqptg*(ne00 + 2*nsgmax*(ncpsg + nqptg))*(sizeof(float)/2); - if (smem > ctx->device.maxThreadgroupMemoryLength) { - break; - } - nsgmax *= 2; - } - nsgmax /= 2; - - // simdgroups per threadgroup (a.k.a. warps) - const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4; - - const size_t smem = nqptg*(ne00 + 2*nsg*(ncpsg + nqptg))*(sizeof(float)/2); - - //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength); - GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); - - [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0]; - - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; - } else { - // half1x4 kernel - const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !! - const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! - - GGML_ASSERT(nqptg <= 32); - GGML_ASSERT(nqptg % 1 == 0); - GGML_ASSERT(ncpsg % 32 == 0); - - // simdgroups per threadgroup (a.k.a. warps) - const int64_t nsgt = MAX(2, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)); - - int64_t nsg = 1; - while (nsg <= nsgt) { - nsg *= 2; - } - nsg /= 2; - - const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2); - - //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength); - GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); - [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0]; - - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; - } - } break; - case GGML_OP_DUP: - case GGML_OP_CPY: - case GGML_OP_CONT: - { - GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0); - - int nth = MIN(1024, ne00/ggml_blck_size(src0->type)); - - id pipeline = nil; - - switch (src0t) { - case GGML_TYPE_F32: - { - GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0); - - switch (dstt) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break; - case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break; - case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break; - case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break; - case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0].pipeline; break; - case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1].pipeline; break; - case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL].pipeline; break; - default: GGML_ABORT("not implemented"); - }; - } break; - case GGML_TYPE_F16: - { - switch (dstt) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F32].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline; break; - default: GGML_ABORT("not implemented"); - }; - } break; - default: GGML_ABORT("not implemented"); - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; - [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12]; - [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14]; - [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15]; - [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; - [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; - - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - default: - { - GGML_METAL_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op)); - GGML_ABORT("fatal error"); + } break; + case GGML_OP_MUL_MAT_ID: + { + const int n_as = src0->ne[2]; + + // src2 = ids + const enum ggml_type src2t = src2->type; GGML_UNUSED(src2t); + + GGML_ASSERT(src2t == GGML_TYPE_I32); + + GGML_ASSERT(!ggml_is_transposed(src0)); + GGML_ASSERT(!ggml_is_transposed(src1)); + + GGML_ASSERT(src1t == GGML_TYPE_F32); + + // find the break-even point where the matrix-matrix kernel becomes more efficient compared + // to the matrix-vector kernel + // ne20 = n_used_experts + // ne21 = n_rows + const int dst_rows = ne20*ne21; + const int dst_rows_min = n_as; + const int dst_rows_max = (ctx->device.maxThreadgroupMemoryLength - 32 - 8192)/4; + + // max size of the rowids array in the kernel shared buffer + GGML_ASSERT(dst_rows <= dst_rows_max); + + // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs + // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel + // !!! + // TODO: for now, always use mat-vec kernels until we figure out how to improve the + // indirect matrix multiplication + // !!! + if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && + ne00 % 32 == 0 && ne00 >= 64 && + dst_rows > dst_rows_min) { + + // some Metal matrix data types require aligned pointers + // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5) + switch (src0->type) { + case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break; + case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break; + default: break; } - } - if (should_capture) { - [encoder popDebugGroup]; + id pipeline = nil; + + switch (src0->type) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32 ].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32 ].pipeline; break; + case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32 ].pipeline; break; + case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32 ].pipeline; break; + case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32 ].pipeline; break; + case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32 ].pipeline; break; + case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32 ].pipeline; break; + case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32 ].pipeline; break; + case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32 ].pipeline; break; + case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32 ].pipeline; break; + case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32 ].pipeline; break; + case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32 ].pipeline; break; + case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break; + case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break; + case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32].pipeline; break; + case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32 ].pipeline; break; + case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32 ].pipeline; break; + case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32 ].pipeline; break; + case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32 ].pipeline; break; + case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break; + case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32 ].pipeline; break; + default: GGML_ABORT("MUL_MAT_ID not implemented"); + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3]; + [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4]; + [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5]; + [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:8]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:9]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:10]; + [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11]; + [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12]; + [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:18]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19]; + + [encoder setThreadgroupMemoryLength:GGML_PAD(8192 + dst_rows*4/*sizeof(ushort2)*/, 16) atIndex:0]; + + [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, n_as) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; + } else { + int nth0 = 32; + int nth1 = 1; + int nrows = 1; + //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); + + id pipeline = nil; + + // use custom matrix x vector kernel + switch (src0t) { + case GGML_TYPE_F32: + { + GGML_ASSERT(src1t == GGML_TYPE_F32); + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32].pipeline; + } break; + case GGML_TYPE_F16: + { + GGML_ASSERT(src1t == GGML_TYPE_F32); + nth0 = 32; + nth1 = 1; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32].pipeline; + } break; + case GGML_TYPE_Q4_0: + { + nth0 = 8; + nth1 = 8; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32].pipeline; + } break; + case GGML_TYPE_Q4_1: + { + nth0 = 8; + nth1 = 8; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32].pipeline; + } break; + case GGML_TYPE_Q5_0: + { + nth0 = 8; + nth1 = 8; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32].pipeline; + } break; + case GGML_TYPE_Q5_1: + { + nth0 = 8; + nth1 = 8; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32].pipeline; + } break; + case GGML_TYPE_Q8_0: + { + nth0 = 8; + nth1 = 8; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline; + } break; + case GGML_TYPE_Q2_K: + { + nth0 = 2; + nth1 = 32; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32].pipeline; + } break; + case GGML_TYPE_Q3_K: + { + nth0 = 2; + nth1 = 32; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32].pipeline; + } break; + case GGML_TYPE_Q4_K: + { + nth0 = 4; //1; + nth1 = 8; //32; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32].pipeline; + } break; + case GGML_TYPE_Q5_K: + { + nth0 = 2; + nth1 = 32; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32].pipeline; + } break; + case GGML_TYPE_Q6_K: + { + nth0 = 2; + nth1 = 32; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32].pipeline; + } break; + case GGML_TYPE_IQ2_XXS: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32].pipeline; + } break; + case GGML_TYPE_IQ2_XS: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32].pipeline; + } break; + case GGML_TYPE_IQ3_XXS: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32].pipeline; + } break; + case GGML_TYPE_IQ3_S: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32].pipeline; + } break; + case GGML_TYPE_IQ2_S: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32].pipeline; + } break; + case GGML_TYPE_IQ1_S: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32].pipeline; + } break; + case GGML_TYPE_IQ1_M: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32].pipeline; + } break; + case GGML_TYPE_IQ4_NL: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline; + } break; + case GGML_TYPE_IQ4_XS: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline; + } break; + default: + { + GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t); + GGML_ABORT("not implemented"); + } + }; + + if (ggml_is_quantized(src0t)) { + GGML_ASSERT(ne00 >= nth0*nth1); + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3]; + [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4]; + [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5]; + [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:8]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:9]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:10]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:11]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:12]; + [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:13]; + [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:14]; + [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:15]; + [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:16]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:17]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:18]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:19]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:20]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:21]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:22]; + + const int64_t _ne1 = 1; + const int tgz = dst_rows; + + if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 || + src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K || + src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) { + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) { + const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128; + [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) { + const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4; + [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) { + const int mem_size = 32*sizeof(float); + [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + else if (src0t == GGML_TYPE_Q4_K) { + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + else if (src0t == GGML_TYPE_Q3_K) { + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + else if (src0t == GGML_TYPE_Q5_K) { + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + else if (src0t == GGML_TYPE_Q6_K) { + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } else { + const int64_t ny = (_ne1 + nrows - 1)/nrows; // = _ne1 + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + } + } break; + case GGML_OP_GET_ROWS: + { + id pipeline = nil; + + switch (src0->type) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F32 ].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F16 ].pipeline; break; + case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0 ].pipeline; break; + case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1 ].pipeline; break; + case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0 ].pipeline; break; + case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1 ].pipeline; break; + case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0 ].pipeline; break; + case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K ].pipeline; break; + case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K ].pipeline; break; + case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K ].pipeline; break; + case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K ].pipeline; break; + case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K ].pipeline; break; + case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS].pipeline; break; + case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break; + case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS].pipeline; break; + case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S ].pipeline; break; + case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S ].pipeline; break; + case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S ].pipeline; break; + case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M ].pipeline; break; + case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL ].pipeline; break; + case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS ].pipeline; break; + case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline; break; + default: GGML_ABORT("not implemented"); + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5]; + [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6]; + [encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7]; + [encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8]; + [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:9]; + [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:10]; + + [encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)]; + } break; + case GGML_OP_RMS_NORM: + { + GGML_ASSERT(ne00 % 4 == 0); + GGML_ASSERT(ggml_is_contiguous_1(src0)); + + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); + + int nth = 32; // SIMD width + + while (nth < ne00/4 && nth < 1024) { + nth *= 2; + } + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3]; + [encoder setBytes:&eps length:sizeof( float) atIndex:4]; + [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; + + const int64_t nrows = ggml_nrows(src0); + + [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_GROUP_NORM: + { + GGML_ASSERT(ne00 % 4 == 0); + GGML_ASSERT(ggml_is_contiguous(src0)); + + float eps; + memcpy(&eps, dst->op_params + 1, sizeof(float)); + + const int32_t n_groups = ((const int32_t *) dst->op_params)[0]; + + int nth = 32; // SIMD width + + //while (nth < ne00/4 && nth < 1024) { + // nth *= 2; + //} + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; + [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:5]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:6]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:7]; + [encoder setBytes:&n_groups length:sizeof( int32_t) atIndex:8]; + [encoder setBytes:&eps length:sizeof( float) atIndex:9]; + [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; + + [encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_NORM: + { + GGML_ASSERT(ggml_is_contiguous_1(src0)); + + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); + + const int nth = MIN(256, ne00); + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_NORM].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3]; + [encoder setBytes:&eps length:sizeof( float) atIndex:4]; + [encoder setThreadgroupMemoryLength:GGML_PAD(nth*sizeof(float), 16) atIndex:0]; + + const int64_t nrows = ggml_nrows(src0); + + [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_ROPE: + { + GGML_ASSERT(ne10 == ne02); + + const int nth = MIN(1024, ne00); + + const int n_past = ((const int32_t *) dst->op_params)[0]; + const int n_dims = ((const int32_t *) dst->op_params)[1]; + const int mode = ((const int32_t *) dst->op_params)[2]; + // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal + const int n_ctx_orig = ((const int32_t *) dst->op_params)[4]; + + float freq_base; + float freq_scale; + float ext_factor; + float attn_factor; + float beta_fast; + float beta_slow; + + memcpy(&freq_base, (const int32_t *) dst->op_params + 5, sizeof(float)); + memcpy(&freq_scale, (const int32_t *) dst->op_params + 6, sizeof(float)); + memcpy(&ext_factor, (const int32_t *) dst->op_params + 7, sizeof(float)); + memcpy(&attn_factor, (const int32_t *) dst->op_params + 8, sizeof(float)); + memcpy(&beta_fast, (const int32_t *) dst->op_params + 9, sizeof(float)); + memcpy(&beta_slow, (const int32_t *) dst->op_params + 10, sizeof(float)); + + const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; + + id pipeline = nil; + + if (!is_neox) { + switch (src0->type) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break; + default: GGML_ABORT("fatal error"); + }; + } else { + switch (src0->type) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break; + default: GGML_ABORT("fatal error"); + }; + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + if (id_src2 != nil) { + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; + } else { + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:2]; + } + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:4]; + [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5]; + [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6]; + [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7]; + [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:8]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:9]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:10]; + [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:11]; + [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:12]; + [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:13]; + [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:14]; + [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:15]; + [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:16]; + [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:17]; + [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:18]; + [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:19]; + [encoder setBytes:&n_past length:sizeof( int) atIndex:20]; + [encoder setBytes:&n_dims length:sizeof( int) atIndex:21]; + [encoder setBytes:&n_ctx_orig length:sizeof( int) atIndex:22]; + [encoder setBytes:&freq_base length:sizeof( float) atIndex:23]; + [encoder setBytes:&freq_scale length:sizeof( float) atIndex:24]; + [encoder setBytes:&ext_factor length:sizeof( float) atIndex:25]; + [encoder setBytes:&attn_factor length:sizeof( float) atIndex:26]; + [encoder setBytes:&beta_fast length:sizeof( float) atIndex:27]; + [encoder setBytes:&beta_slow length:sizeof( float) atIndex:28]; + + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_IM2COL: + { + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32); + + const int32_t s0 = ((const int32_t *)(dst->op_params))[0]; + const int32_t s1 = ((const int32_t *)(dst->op_params))[1]; + const int32_t p0 = ((const int32_t *)(dst->op_params))[2]; + const int32_t p1 = ((const int32_t *)(dst->op_params))[3]; + const int32_t d0 = ((const int32_t *)(dst->op_params))[4]; + const int32_t d1 = ((const int32_t *)(dst->op_params))[5]; + + const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1; + + const int32_t N = src1->ne[is_2D ? 3 : 2]; + const int32_t IC = src1->ne[is_2D ? 2 : 1]; + const int32_t IH = is_2D ? src1->ne[1] : 1; + const int32_t IW = src1->ne[0]; + + const int32_t KH = is_2D ? src0->ne[1] : 1; + const int32_t KW = src0->ne[0]; + + const int32_t OH = is_2D ? dst->ne[2] : 1; + const int32_t OW = dst->ne[1]; + + const int32_t CHW = IC * KH * KW; + + const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4; + const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4; + + id pipeline = nil; + + switch (dst->type) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline; break; + default: GGML_ABORT("fatal error"); + }; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2]; + [encoder setBytes:&ofs1 length:sizeof( int32_t) atIndex:3]; + [encoder setBytes:&IW length:sizeof( int32_t) atIndex:4]; + [encoder setBytes:&IH length:sizeof( int32_t) atIndex:5]; + [encoder setBytes:&CHW length:sizeof( int32_t) atIndex:6]; + [encoder setBytes:&s0 length:sizeof( int32_t) atIndex:7]; + [encoder setBytes:&s1 length:sizeof( int32_t) atIndex:8]; + [encoder setBytes:&p0 length:sizeof( int32_t) atIndex:9]; + [encoder setBytes:&p1 length:sizeof( int32_t) atIndex:10]; + [encoder setBytes:&d0 length:sizeof( int32_t) atIndex:11]; + [encoder setBytes:&d1 length:sizeof( int32_t) atIndex:12]; + + [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)]; + } break; + case GGML_OP_UPSCALE: + { + GGML_ASSERT(src0->type == GGML_TYPE_F32); + + const float sf0 = (float)ne0/src0->ne[0]; + const float sf1 = (float)ne1/src0->ne[1]; + const float sf2 = (float)ne2/src0->ne[2]; + const float sf3 = (float)ne3/src0->ne[3]; + + const id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; + [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11]; + [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12]; + [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13]; + [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15]; + [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16]; + [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17]; + [encoder setBytes:&sf0 length:sizeof(sf0) atIndex:18]; + [encoder setBytes:&sf1 length:sizeof(sf1) atIndex:19]; + [encoder setBytes:&sf2 length:sizeof(sf2) atIndex:20]; + [encoder setBytes:&sf3 length:sizeof(sf3) atIndex:21]; + + const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0); + + [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_PAD: + { + GGML_ASSERT(src0->type == GGML_TYPE_F32); + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; + [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11]; + [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12]; + [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13]; + [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15]; + [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16]; + [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17]; + + const int nth = MIN(1024, ne0); + + [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_ARANGE: + { + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + float start; + float step; + + memcpy(&start, ((const int32_t *) dst->op_params) + 0, sizeof(float)); + memcpy(&step, ((const int32_t *) dst->op_params) + 2, sizeof(float)); + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARANGE_F32].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:0]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:1]; + [encoder setBytes:&start length:sizeof(start) atIndex:2]; + [encoder setBytes:&step length:sizeof(step) atIndex:3]; + + const int nth = MIN(1024, ne0); + + [encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_TIMESTEP_EMBEDDING: + { + GGML_ASSERT(src0->type == GGML_TYPE_F32); + + const int dim = dst->op_params[0]; + const int max_period = dst->op_params[1]; + + const int half = dim / 2; + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:2]; + [encoder setBytes:&dim length:sizeof(dim) atIndex:3]; + [encoder setBytes:&max_period length:sizeof(max_period) atIndex:4]; + + const int nth = MIN(1024, half); + + [encoder dispatchThreadgroups:MTLSizeMake(ne00, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_ARGSORT: + { + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_I32); + + const int nrows = ggml_nrows(src0); + + enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0]; + + // bitonic sort requires the number of elements to be power of 2 + int64_t ne00_padded = 1; + while (ne00_padded < ne00) { + ne00_padded *= 2; + } + + // Metal kernels require the buffer size to be multiple of 16 bytes + // https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength + const int mem_size = GGML_PAD(ne00_padded*sizeof(int32_t), 16); + + id pipeline = nil; + + switch (order) { + case GGML_SORT_ORDER_ASC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC].pipeline; break; + case GGML_SORT_ORDER_DESC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC].pipeline; break; + default: GGML_ABORT("fatal error"); + }; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; + [encoder setBytes:&ne00_padded length:sizeof( int64_t) atIndex:3]; + [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; + + [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00_padded, 1, 1)]; + } break; + case GGML_OP_LEAKY_RELU: + { + GGML_ASSERT(src0->type == GGML_TYPE_F32); + + float slope; + memcpy(&slope, dst->op_params, sizeof(float)); + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&slope length:sizeof(slope) atIndex:2]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_FLASH_ATTN_EXT: + { + GGML_ASSERT(ne00 % 4 == 0); + GGML_ASSERT(ne11 % 32 == 0); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + + GGML_ASSERT(ggml_are_same_shape (src1, src2)); + + struct ggml_tensor * src3 = node->src[3]; + + size_t offs_src3 = 0; + + id id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil; + + GGML_ASSERT(!src3 || src3->type == GGML_TYPE_F16); + GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) && + "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big"); + + const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30); + //const int64_t ne31 = src3 ? src3->ne[1] : 0; + const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32); + const int64_t ne33 = src3 ? src3->ne[3] : 0; GGML_UNUSED(ne33); + + const uint64_t nb30 = src3 ? src3->nb[0] : 0; GGML_UNUSED(nb30); + const uint64_t nb31 = src3 ? src3->nb[1] : 0; + const uint64_t nb32 = src3 ? src3->nb[2] : 0; GGML_UNUSED(nb32); + const uint64_t nb33 = src3 ? src3->nb[3] : 0; GGML_UNUSED(nb33); + + const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t); + + float scale; + float max_bias; + float logit_softcap; + memcpy(&scale, ((const int32_t *) dst->op_params) + 0, sizeof(scale)); + memcpy(&max_bias, ((const int32_t *) dst->op_params) + 1, sizeof(max_bias)); + memcpy(&logit_softcap, ((const int32_t *) dst->op_params) + 2, sizeof(logit_softcap)); + + if (logit_softcap != 0.0f) { + scale /= logit_softcap; + } + + const uint32_t n_head = src0->ne[2]; + const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); + + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + id pipeline = nil; + + bool use_vec_kernel = false; + + if (ne01 >= 4 || (ne00%128 != 0)) { + switch (ne00) { + case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break; + case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break; + case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break; + case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break; + case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break; + //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break; + default: + { + GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00); + GGML_METAL_LOG_ERROR("add template specialization for this size\n"); + GGML_ABORT("add template specialization for this size"); + } + } + } else { + use_vec_kernel = true; + + switch (ne00) { + case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break; + //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break; + default: + { + GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00); + GGML_METAL_LOG_ERROR("add template specialization for this size\n"); + GGML_ABORT("add template specialization for this size"); + } + } + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; + if (id_src3) { + [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3]; + } else { + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:3]; + } + [encoder setBuffer:id_dst offset:offs_dst atIndex:4]; + [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5]; + [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6]; + [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9]; + [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10]; + [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:11]; + [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:12]; + [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:13]; + [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:14]; + [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:15]; + [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:16]; + [encoder setBytes:&nb21 length:sizeof(uint64_t) atIndex:17]; + [encoder setBytes:&nb22 length:sizeof(uint64_t) atIndex:18]; + [encoder setBytes:&nb23 length:sizeof(uint64_t) atIndex:19]; + [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:20]; + [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:21]; + [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:22]; + [encoder setBytes:&scale length:sizeof( float) atIndex:23]; + [encoder setBytes:&max_bias length:sizeof( float) atIndex:24]; + [encoder setBytes:&m0 length:sizeof(m0) atIndex:25]; + [encoder setBytes:&m1 length:sizeof(m1) atIndex:26]; + [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:27]; + [encoder setBytes:&logit_softcap length:sizeof(logit_softcap) atIndex:28]; + + if (!use_vec_kernel) { + // half8x8 kernel + const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! + const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! + + GGML_ASSERT(nqptg <= 32); + GGML_ASSERT(nqptg % 8 == 0); + GGML_ASSERT(ncpsg % 32 == 0); + + int64_t nsgmax = 2; + + while (true) { + const size_t smem = nqptg*(ne00 + 2*nsgmax*(ncpsg + nqptg))*(sizeof(float)/2); + if (smem > ctx->device.maxThreadgroupMemoryLength) { + break; + } + nsgmax *= 2; + } + nsgmax /= 2; + + // simdgroups per threadgroup (a.k.a. warps) + const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4; + + const size_t smem = nqptg*(ne00 + 2*nsg*(ncpsg + nqptg))*(sizeof(float)/2); + + //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength); + GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); + + [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0]; + + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; + } else { + // half1x4 kernel + const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !! + const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! + + GGML_ASSERT(nqptg <= 32); + GGML_ASSERT(nqptg % 1 == 0); + GGML_ASSERT(ncpsg % 32 == 0); + + // simdgroups per threadgroup (a.k.a. warps) + const int64_t nsgt = MAX(2, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)); + + int64_t nsg = 1; + while (nsg <= nsgt) { + nsg *= 2; + } + nsg /= 2; + + const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2); + + //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength); + GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); + [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0]; + + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; + } + } break; + case GGML_OP_DUP: + case GGML_OP_CPY: + case GGML_OP_CONT: + { + GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0); + + int nth = MIN(1024, ne00/ggml_blck_size(src0->type)); + + id pipeline = nil; + + switch (src0t) { + case GGML_TYPE_F32: + { + GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0); + + switch (dstt) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break; + case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break; + case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break; + case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break; + case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0].pipeline; break; + case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1].pipeline; break; + case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL].pipeline; break; + default: GGML_ABORT("not implemented"); + }; + } break; + case GGML_TYPE_F16: + { + switch (dstt) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline; break; + default: GGML_ABORT("not implemented"); + }; + } break; + default: GGML_ABORT("not implemented"); + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; + [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5]; + [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8]; + [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9]; + [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10]; + [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11]; + [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12]; + [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13]; + [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14]; + [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15]; + [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; + [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; + + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + default: + { + GGML_METAL_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op)); + GGML_ABORT("fatal error"); + } + } +} + +static enum ggml_status ggml_metal_graph_compute( + struct ggml_backend_metal_context * ctx, + struct ggml_cgraph * gf) { + // number of nodes encoded by the main thread (empirically determined) + const int n_main = 128; + + // number of threads in addition to the main thread + const int n_cb = ctx->n_cb; + + // submit the ggml compute graph to the GPU by creating command buffers and encoding the ops in them + // the first n_nodes_0 are encoded and submitted for processing directly by the calling thread + // while these nodes are processing, we start n_cb threads to enqueue the rest of the nodes + // each thread creates it's own command buffer and enqueues the ops in parallel + // + // tests on M1 Pro and M2 Ultra using LLaMA models, show that optimal values for n_cb are 1 or 2 + + @autoreleasepool { + ctx->gf = gf; + + ctx->n_nodes_0 = MIN(n_main, gf->n_nodes); + ctx->n_nodes_1 = gf->n_nodes - ctx->n_nodes_0; + + ctx->n_nodes_per_cb = (ctx->n_nodes_1 + ctx->n_cb - 1) / ctx->n_cb; + + const bool should_capture = ctx->capture_next_compute; + if (should_capture) { + ctx->capture_next_compute = false; + + if (!ctx->capture_started) { + // create capture scope + ctx->capture_scope = [[MTLCaptureManager sharedCaptureManager] newCaptureScopeWithDevice:ctx->device]; + + MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new]; + descriptor.captureObject = ctx->capture_scope; + descriptor.destination = MTLCaptureDestinationGPUTraceDocument; + descriptor.outputURL = [NSURL fileURLWithPath:[NSString stringWithFormat:@"/tmp/perf-metal.gputrace"]]; + + NSError * error = nil; + if (![[MTLCaptureManager sharedCaptureManager] startCaptureWithDescriptor:descriptor error:&error]) { + GGML_METAL_LOG_ERROR("%s: error: unable to start capture '%s'\n", __func__, [[error localizedDescription] UTF8String]); + GGML_ABORT("capture failed"); + } else { + [ctx->capture_scope beginScope]; + ctx->capture_started = true; + } } } - [encoder endEncoding]; + // TODO: how to avoid this allocation? I tried initializing it in ggml_backend_metal_set_n_cb but it crashes. + ctx->encode_async = ^(size_t iter) { + const int cb_idx = iter; + const int n_cb_l = ctx->n_cb; - if (cb_idx < 2 || ctx->abort_callback == NULL) { - [command_buffer commit]; - } - }); + const int n_nodes_0 = ctx->n_nodes_0; + const int n_nodes_1 = ctx->n_nodes_1; - // Wait for completion and check status of each command buffer - // needed to detect if the device ran out-of-memory for example (#1881) + const int n_nodes_per_cb = ctx->n_nodes_per_cb; - for (int i = 0; i < n_cb; ++i) { - id command_buffer = command_buffers[i]; - [command_buffer waitUntilCompleted]; + id command_buffer = ctx->command_buffers[cb_idx]; + id encoder = [command_buffer computeCommandEncoderWithDescriptor: ctx->edesc]; - MTLCommandBufferStatus status = [command_buffer status]; - if (status != MTLCommandBufferStatusCompleted) { - GGML_METAL_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status); - if (status == MTLCommandBufferStatusError) { - GGML_METAL_LOG_INFO("error: %s\n", [[command_buffer error].localizedDescription UTF8String]); + int node_start = 0; + int node_end = n_nodes_0; + + if (cb_idx < n_cb_l) { + node_start = n_nodes_0 + ( (cb_idx + 0) * n_nodes_per_cb); + node_end = n_nodes_0 + (MIN((cb_idx == n_cb_l - 1) ? n_nodes_1 : (cb_idx + 1) * n_nodes_per_cb, n_nodes_1)); } - return GGML_STATUS_FAILED; + for (int idx = node_start; idx < node_end; ++idx) { + if (should_capture) { + [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(gf, idx)) encoding:NSUTF8StringEncoding]]; + } + + ggml_metal_encode_node(ctx, idx, encoder); + + if (should_capture) { + [encoder popDebugGroup]; + } + } + + [encoder endEncoding]; + + if (cb_idx < 2 || ctx->abort_callback == NULL) { + [command_buffer commit]; + } + }; + + // the main thread commits the first few commands immediately + // command_buffer[n_cb] + { + id command_buffer = [ctx->queue commandBufferWithUnretainedReferences]; + ctx->command_buffers[n_cb] = command_buffer; + + [command_buffer enqueue]; + ctx->encode_async(n_cb); } - id next_buffer = (i + 1 < n_cb ? command_buffers[i + 1] : nil); - if (!next_buffer) { - continue; + // prepare the rest of the command buffers asynchronously + // command_buffer[0.. n_cb) + for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) { + id command_buffer = [ctx->queue commandBufferWithUnretainedReferences]; + ctx->command_buffers[cb_idx] = command_buffer; + + // always enqueue the first two command buffers + // enqueue all of the command buffers if we don't need to abort + if (cb_idx < 2 || ctx->abort_callback == NULL) { + [command_buffer enqueue]; + } } - bool next_queued = ([next_buffer status] != MTLCommandBufferStatusNotEnqueued); - if (next_queued) { - continue; + dispatch_apply(n_cb, ctx->d_queue, ctx->encode_async); + + // wait for completion and check status of each command buffer + // needed to detect if the device ran out-of-memory for example (#1881) + { + id command_buffer = ctx->command_buffers[n_cb]; + [command_buffer waitUntilCompleted]; + + MTLCommandBufferStatus status = [command_buffer status]; + if (status != MTLCommandBufferStatusCompleted) { + GGML_METAL_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, n_cb, status); + if (status == MTLCommandBufferStatusError) { + GGML_METAL_LOG_INFO("error: %s\n", [[command_buffer error].localizedDescription UTF8String]); + } + + return GGML_STATUS_FAILED; + } } - if (ctx->abort_callback && ctx->abort_callback(ctx->abort_callback_data)) { - GGML_METAL_LOG_INFO("%s: command buffer %d aborted", __func__, i); - return GGML_STATUS_ABORTED; + for (int i = 0; i < n_cb; ++i) { + id command_buffer = ctx->command_buffers[i]; + [command_buffer waitUntilCompleted]; + + MTLCommandBufferStatus status = [command_buffer status]; + if (status != MTLCommandBufferStatusCompleted) { + GGML_METAL_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status); + if (status == MTLCommandBufferStatusError) { + GGML_METAL_LOG_INFO("error: %s\n", [[command_buffer error].localizedDescription UTF8String]); + } + + return GGML_STATUS_FAILED; + } + + id next_buffer = (i + 1 < n_cb ? ctx->command_buffers[i + 1] : nil); + if (!next_buffer) { + continue; + } + + const bool next_queued = ([next_buffer status] != MTLCommandBufferStatusNotEnqueued); + if (next_queued) { + continue; + } + + if (ctx->abort_callback && ctx->abort_callback(ctx->abort_callback_data)) { + GGML_METAL_LOG_INFO("%s: command buffer %d aborted", __func__, i); + return GGML_STATUS_ABORTED; + } + + [next_buffer commit]; } - [next_buffer commit]; + if (!should_capture && ctx->capture_started) { + [ctx->capture_scope endScope]; + [[MTLCaptureManager sharedCaptureManager] stopCapture]; + } } - if (should_capture) { - [[MTLCaptureManager sharedCaptureManager] stopCapture]; - } - - } return GGML_STATUS_SUCCESS; } @@ -3405,6 +3504,25 @@ GGML_CALL static bool ggml_backend_metal_supports_buft(ggml_backend_t backend, g UNUSED(backend); } +static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) { + GGML_ASSERT(ggml_backend_is_metal(backend)); + + struct ggml_backend_metal_context * ctx = (struct ggml_backend_metal_context *)backend->context; + + if (ctx->n_cb != n_cb) { + ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_COMMAND_BUFFERS); + + if (ctx->n_cb > 2) { + GGML_METAL_LOG_WARN("%s: n_cb = %d, using n_cb > 2 is not recommended and can degrade the performance in some cases\n", __func__, n_cb); + } + } + + // TODO: setting encode_async here causes crash during the next ggml_metal_graph_compute call. why? + //ctx->encode_async = ^(size_t iter) { + // ... + //}; +} + static struct ggml_backend_i ggml_backend_metal_i = { /* .get_name = */ ggml_backend_metal_name, /* .free = */ ggml_backend_metal_free, @@ -3439,35 +3557,29 @@ static ggml_guid_t ggml_backend_metal_guid(void) { } ggml_backend_t ggml_backend_metal_init(void) { - struct ggml_backend_metal_context * ctx = ggml_metal_init(GGML_DEFAULT_N_THREADS); + struct ggml_backend_metal_context * ctx = ggml_metal_init(); if (ctx == NULL) { GGML_METAL_LOG_ERROR("%s: error: failed to allocate context\n", __func__); return NULL; } - ggml_backend_t metal_backend = malloc(sizeof(struct ggml_backend)); + ggml_backend_t backend = malloc(sizeof(struct ggml_backend)); - *metal_backend = (struct ggml_backend) { + *backend = (struct ggml_backend) { /* .guid = */ ggml_backend_metal_guid(), /* .interface = */ ggml_backend_metal_i, /* .context = */ ctx, }; - return metal_backend; + ggml_backend_metal_set_n_cb(backend, 1); + + return backend; } bool ggml_backend_is_metal(ggml_backend_t backend) { return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_metal_guid()); } -void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) { - GGML_ASSERT(ggml_backend_is_metal(backend)); - - struct ggml_backend_metal_context * ctx = (struct ggml_backend_metal_context *)backend->context; - - ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS); -} - void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data) { GGML_ASSERT(ggml_backend_is_metal(backend)); @@ -3489,7 +3601,7 @@ void ggml_backend_metal_capture_next_compute(ggml_backend_t backend) { GGML_ASSERT(ggml_backend_is_metal(backend)); struct ggml_backend_metal_context * ctx = (struct ggml_backend_metal_context *)backend->context; - ctx->should_capture_next_compute = true; + ctx->capture_next_compute = true; } GGML_CALL ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data); // silence warning diff --git a/src/llama.cpp b/src/llama.cpp index d1d27d21e..4c0a1bb61 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -17025,12 +17025,6 @@ static void llama_graph_compute( ggml_cgraph * gf, int n_threads, ggml_threadpool * threadpool) { -#ifdef GGML_USE_METAL - if (ggml_backend_is_metal(lctx.backend_metal)) { - ggml_backend_metal_set_n_cb(lctx.backend_metal, n_threads); - } -#endif - if (lctx.backend_cpu != nullptr) { ggml_backend_cpu_set_n_threads(lctx.backend_cpu, n_threads); ggml_backend_cpu_set_threadpool(lctx.backend_cpu, threadpool); From 7254cdf7e8d6312840509ca8d766358c687c6266 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Sun, 29 Sep 2024 23:18:02 +0200 Subject: [PATCH 29/33] ggml: fix gradient allocation logic (ggml/966) * ggml: fix gradient allocation logic * gradient allocation in ggml_build_backward_expand * fixup * fix test-backend-ops grad * suggestions by slaren * fix test1.c * fix legacy opt API * fix test-grad0 * remove keep arg --- ggml/include/ggml.h | 40 +- ggml/src/ggml.c | 1464 +++++++++++------------------------- tests/test-backend-ops.cpp | 37 +- tests/test-grad0.cpp | 14 +- 4 files changed, 490 insertions(+), 1065 deletions(-) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index f46d4a8a6..a8a74bee1 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -577,10 +577,10 @@ extern "C" { // this tensor... enum ggml_tensor_flag { - GGML_TENSOR_FLAG_INPUT = 1, // ...is an input for the GGML compute graph - GGML_TENSOR_FLAG_OUTPUT = 2, // ...is an output for the GGML compute graph - GGML_TENSOR_FLAG_PARAM = 4, // ...contains trainable parameters - GGML_TENSOR_FLAG_LOSS = 8, // ...defines loss for numerical optimization (multiple loss tensors add up) + GGML_TENSOR_FLAG_INPUT = 1, // ...is an input for the GGML compute graph + GGML_TENSOR_FLAG_OUTPUT = 2, // ...is an output for the GGML compute graph + GGML_TENSOR_FLAG_PARAM = 4, // ...contains trainable parameters + GGML_TENSOR_FLAG_LOSS = 8, // ...defines loss for numerical optimization (multiple loss tensors add up) }; // n-dimensional tensor @@ -1410,14 +1410,14 @@ extern "C" { // supports 3D: a->ne[2] == b->ne[1] GGML_API struct ggml_tensor * ggml_get_rows( struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b); + struct ggml_tensor * a, // data + struct ggml_tensor * b); // row indices GGML_API struct ggml_tensor * ggml_get_rows_back( struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - struct ggml_tensor * c); + struct ggml_tensor * a, // gradients of ggml_get_rows result + struct ggml_tensor * b, // row indices + struct ggml_tensor * c); // data for ggml_get_rows, only used for its shape GGML_API struct ggml_tensor * ggml_diag( struct ggml_context * ctx, @@ -1568,9 +1568,9 @@ extern "C" { // a - dy GGML_API struct ggml_tensor * ggml_rope_back( struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - struct ggml_tensor * c, + struct ggml_tensor * a, // gradients of ggml_rope result + struct ggml_tensor * b, // positions + struct ggml_tensor * c, // freq factors int n_dims, int mode, int n_ctx_orig, @@ -2036,15 +2036,15 @@ extern "C" { // loss function GGML_API struct ggml_tensor * ggml_cross_entropy_loss( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b); + struct ggml_context * ctx, + struct ggml_tensor * a, // logits + struct ggml_tensor * b); // labels GGML_API struct ggml_tensor * ggml_cross_entropy_loss_back( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - struct ggml_tensor * c); + struct ggml_context * ctx, + struct ggml_tensor * a, // logits + struct ggml_tensor * b, // labels + struct ggml_tensor * c); // gradients of cross_entropy_loss result // AdamW optimizer step // Paper: https://arxiv.org/pdf/1711.05101v3.pdf @@ -2066,7 +2066,7 @@ extern "C" { GGML_API void ggml_set_loss(struct ggml_tensor * tensor); GGML_API void ggml_build_forward_expand (struct ggml_cgraph * cgraph, struct ggml_tensor * tensor); - GGML_API void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool accumulate, bool keep); + GGML_API void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool accumulate); GGML_API void ggml_build_opt_adamw( struct ggml_context * ctx, diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 81b651c6a..fa8d6c25a 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -4725,18 +4725,11 @@ struct ggml_tensor * ggml_get_tensor(struct ggml_context * ctx, const char * nam static struct ggml_tensor * ggml_dup_impl( struct ggml_context * ctx, - struct ggml_tensor * a, - bool inplace) { - bool is_node = false; - - if (!inplace && (a->grad)) { - is_node = true; - } - + struct ggml_tensor * a, + bool inplace) { struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - result->op = GGML_OP_DUP; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_DUP; result->src[0] = a; return result; @@ -4744,13 +4737,13 @@ static struct ggml_tensor * ggml_dup_impl( struct ggml_tensor * ggml_dup( struct ggml_context * ctx, - struct ggml_tensor * a) { + struct ggml_tensor * a) { return ggml_dup_impl(ctx, a, false); } struct ggml_tensor * ggml_dup_inplace( struct ggml_context * ctx, - struct ggml_tensor * a) { + struct ggml_tensor * a) { return ggml_dup_impl(ctx, a, true); } @@ -4758,21 +4751,14 @@ struct ggml_tensor * ggml_dup_inplace( static struct ggml_tensor * ggml_add_impl( struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - bool inplace) { + struct ggml_tensor * a, + struct ggml_tensor * b, + bool inplace) { GGML_ASSERT(ggml_can_repeat(b, a)); - bool is_node = false; - - if (!inplace && (a->grad || b->grad)) { - is_node = true; - } - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - result->op = GGML_OP_ADD; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_ADD; result->src[0] = a; result->src[1] = b; @@ -4781,15 +4767,15 @@ static struct ggml_tensor * ggml_add_impl( struct ggml_tensor * ggml_add( struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b) { + struct ggml_tensor * a, + struct ggml_tensor * b) { return ggml_add_impl(ctx, a, b, false); } struct ggml_tensor * ggml_add_inplace( struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b) { + struct ggml_tensor * a, + struct ggml_tensor * b) { return ggml_add_impl(ctx, a, b, true); } @@ -4797,9 +4783,9 @@ struct ggml_tensor * ggml_add_inplace( static struct ggml_tensor * ggml_add_cast_impl( struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - enum ggml_type type) { + struct ggml_tensor * a, + struct ggml_tensor * b, + enum ggml_type type) { // TODO: support less-strict constraint // GGML_ASSERT(ggml_can_repeat(b, a)); GGML_ASSERT(ggml_can_repeat_rows(b, a)); @@ -4809,18 +4795,9 @@ static struct ggml_tensor * ggml_add_cast_impl( a->type == GGML_TYPE_F16 || a->type == GGML_TYPE_BF16); - bool is_node = false; - - if (a->grad || b->grad) { - // TODO: support backward pass for broadcasting - GGML_ASSERT(ggml_are_same_shape(a, b)); - is_node = true; - } - struct ggml_tensor * result = ggml_new_tensor(ctx, type, GGML_MAX_DIMS, a->ne); - result->op = GGML_OP_ADD; - result->grad = is_node ? ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, a->ne) : NULL; + result->op = GGML_OP_ADD; result->src[0] = a; result->src[1] = b; @@ -4829,9 +4806,9 @@ static struct ggml_tensor * ggml_add_cast_impl( struct ggml_tensor * ggml_add_cast( struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - enum ggml_type type) { + struct ggml_tensor * a, + struct ggml_tensor * b, + enum ggml_type type) { return ggml_add_cast_impl(ctx, a, b, type); } @@ -4839,22 +4816,15 @@ struct ggml_tensor * ggml_add_cast( static struct ggml_tensor * ggml_add1_impl( struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - bool inplace) { + struct ggml_tensor * a, + struct ggml_tensor * b, + bool inplace) { GGML_ASSERT(ggml_is_scalar(b)); GGML_ASSERT(ggml_is_padded_1d(a)); - bool is_node = false; - - if (a->grad || b->grad) { - is_node = true; - } - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - result->op = GGML_OP_ADD1; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_ADD1; result->src[0] = a; result->src[1] = b; @@ -4863,15 +4833,15 @@ static struct ggml_tensor * ggml_add1_impl( struct ggml_tensor * ggml_add1( struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b) { + struct ggml_tensor * a, + struct ggml_tensor * b) { return ggml_add1_impl(ctx, a, b, false); } struct ggml_tensor * ggml_add1_inplace( struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b) { + struct ggml_tensor * a, + struct ggml_tensor * b) { return ggml_add1_impl(ctx, a, b, true); } @@ -4879,31 +4849,24 @@ struct ggml_tensor * ggml_add1_inplace( static struct ggml_tensor * ggml_acc_impl( struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - size_t nb1, - size_t nb2, - size_t nb3, - size_t offset, - bool inplace) { + struct ggml_tensor * a, + struct ggml_tensor * b, + size_t nb1, + size_t nb2, + size_t nb3, + size_t offset, + bool inplace) { GGML_ASSERT(ggml_nelements(b) <= ggml_nelements(a)); GGML_ASSERT(ggml_is_contiguous(a)); GGML_ASSERT(a->type == GGML_TYPE_F32); GGML_ASSERT(b->type == GGML_TYPE_F32); - bool is_node = false; - - if (!inplace && (a->grad || b->grad)) { - is_node = true; - } - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); int32_t params[] = { nb1, nb2, nb3, offset, inplace ? 1 : 0 }; ggml_set_op_params(result, params, sizeof(params)); - result->op = GGML_OP_ACC; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_ACC; result->src[0] = a; result->src[1] = b; @@ -4912,23 +4875,23 @@ static struct ggml_tensor * ggml_acc_impl( struct ggml_tensor * ggml_acc( struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - size_t nb1, - size_t nb2, - size_t nb3, - size_t offset) { + struct ggml_tensor * a, + struct ggml_tensor * b, + size_t nb1, + size_t nb2, + size_t nb3, + size_t offset) { return ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, false); } struct ggml_tensor * ggml_acc_inplace( struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - size_t nb1, - size_t nb2, - size_t nb3, - size_t offset) { + struct ggml_tensor * a, + struct ggml_tensor * b, + size_t nb1, + size_t nb2, + size_t nb3, + size_t offset) { return ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, true); } @@ -4936,23 +4899,14 @@ struct ggml_tensor * ggml_acc_inplace( static struct ggml_tensor * ggml_sub_impl( struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - bool inplace) { + struct ggml_tensor * a, + struct ggml_tensor * b, + bool inplace) { GGML_ASSERT(ggml_can_repeat(b, a)); - bool is_node = false; - - if (!inplace && (a->grad || b->grad)) { - // TODO: support backward pass for broadcasting - GGML_ASSERT(ggml_are_same_shape(a, b)); - is_node = true; - } - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - result->op = GGML_OP_SUB; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_SUB; result->src[0] = a; result->src[1] = b; @@ -4961,15 +4915,15 @@ static struct ggml_tensor * ggml_sub_impl( struct ggml_tensor * ggml_sub( struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b) { + struct ggml_tensor * a, + struct ggml_tensor * b) { return ggml_sub_impl(ctx, a, b, false); } struct ggml_tensor * ggml_sub_inplace( struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b) { + struct ggml_tensor * a, + struct ggml_tensor * b) { return ggml_sub_impl(ctx, a, b, true); } @@ -4977,27 +4931,14 @@ struct ggml_tensor * ggml_sub_inplace( static struct ggml_tensor * ggml_mul_impl( struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - bool inplace) { + struct ggml_tensor * a, + struct ggml_tensor * b, + bool inplace) { GGML_ASSERT(ggml_can_repeat(b, a)); - bool is_node = false; - - if (!inplace && (a->grad || b->grad)) { - // TODO: support backward pass for broadcasting - GGML_ASSERT(ggml_are_same_shape(a, b)); - is_node = true; - } - - if (inplace) { - GGML_ASSERT(!is_node); - } - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - result->op = GGML_OP_MUL; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_MUL; result->src[0] = a; result->src[1] = b; @@ -5022,25 +4963,14 @@ struct ggml_tensor * ggml_mul_inplace( static struct ggml_tensor * ggml_div_impl( struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - bool inplace) { + struct ggml_tensor * a, + struct ggml_tensor * b, + bool inplace) { GGML_ASSERT(ggml_can_repeat(b, a)); - bool is_node = false; - - if (!inplace && (a->grad || b->grad)) { - is_node = true; - } - - if (inplace) { - GGML_ASSERT(!is_node); - } - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - result->op = GGML_OP_DIV; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_DIV; result->src[0] = a; result->src[1] = b; @@ -5065,18 +4995,11 @@ struct ggml_tensor * ggml_div_inplace( static struct ggml_tensor * ggml_sqr_impl( struct ggml_context * ctx, - struct ggml_tensor * a, - bool inplace) { - bool is_node = false; - - if (!inplace && (a->grad)) { - is_node = true; - } - + struct ggml_tensor * a, + bool inplace) { struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - result->op = GGML_OP_SQR; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_SQR; result->src[0] = a; return result; @@ -5098,18 +5021,11 @@ struct ggml_tensor * ggml_sqr_inplace( static struct ggml_tensor * ggml_sqrt_impl( struct ggml_context * ctx, - struct ggml_tensor * a, - bool inplace) { - bool is_node = false; - - if (!inplace && (a->grad)) { - is_node = true; - } - + struct ggml_tensor * a, + bool inplace) { struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - result->op = GGML_OP_SQRT; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_SQRT; result->src[0] = a; return result; @@ -5132,17 +5048,10 @@ struct ggml_tensor * ggml_sqrt_inplace( static struct ggml_tensor * ggml_log_impl( struct ggml_context * ctx, struct ggml_tensor * a, - bool inplace) { - bool is_node = false; - - if (!inplace && (a->grad)) { - is_node = true; - } - + bool inplace) { struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - result->op = GGML_OP_LOG; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_LOG; result->src[0] = a; return result; @@ -5165,17 +5074,10 @@ struct ggml_tensor * ggml_log_inplace( static struct ggml_tensor * ggml_sin_impl( struct ggml_context * ctx, struct ggml_tensor * a, - bool inplace) { - bool is_node = false; - - if (!inplace && (a->grad)) { - is_node = true; - } - + bool inplace) { struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - result->op = GGML_OP_SIN; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_SIN; result->src[0] = a; return result; @@ -5198,17 +5100,10 @@ struct ggml_tensor * ggml_sin_inplace( static struct ggml_tensor * ggml_cos_impl( struct ggml_context * ctx, struct ggml_tensor * a, - bool inplace) { - bool is_node = false; - - if (!inplace && (a->grad)) { - is_node = true; - } - + bool inplace) { struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - result->op = GGML_OP_COS; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_COS; result->src[0] = a; return result; @@ -5230,17 +5125,10 @@ struct ggml_tensor * ggml_cos_inplace( struct ggml_tensor * ggml_sum( struct ggml_context * ctx, - struct ggml_tensor * a) { - bool is_node = false; - - if (a->grad) { - is_node = true; - } - + struct ggml_tensor * a) { struct ggml_tensor * result = ggml_new_tensor_1d(ctx, a->type, 1); - result->op = GGML_OP_SUM; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_SUM; result->src[0] = a; return result; @@ -5250,13 +5138,7 @@ struct ggml_tensor * ggml_sum( struct ggml_tensor * ggml_sum_rows( struct ggml_context * ctx, - struct ggml_tensor * a) { - bool is_node = false; - - if (a->grad) { - is_node = true; - } - + struct ggml_tensor * a) { int64_t ne[GGML_MAX_DIMS] = { 1 }; for (int i = 1; i < GGML_MAX_DIMS; ++i) { ne[i] = a->ne[i]; @@ -5264,8 +5146,7 @@ struct ggml_tensor * ggml_sum_rows( struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, GGML_MAX_DIMS, ne); - result->op = GGML_OP_SUM_ROWS; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_SUM_ROWS; result->src[0] = a; return result; @@ -5275,19 +5156,11 @@ struct ggml_tensor * ggml_sum_rows( struct ggml_tensor * ggml_mean( struct ggml_context * ctx, - struct ggml_tensor * a) { - bool is_node = false; - - if (a->grad) { - GGML_ABORT("fatal error"); // TODO: implement - is_node = true; - } - + struct ggml_tensor * a) { int64_t ne[4] = { 1, a->ne[1], a->ne[2], a->ne[3] }; struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); - result->op = GGML_OP_MEAN; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_MEAN; result->src[0] = a; return result; @@ -5297,19 +5170,12 @@ struct ggml_tensor * ggml_mean( struct ggml_tensor * ggml_argmax( struct ggml_context * ctx, - struct ggml_tensor * a) { + struct ggml_tensor * a) { GGML_ASSERT(ggml_is_matrix(a)); - bool is_node = false; - - if (a->grad) { - GGML_ABORT("fatal error"); - is_node = true; - } struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, a->ne[1]); - result->op = GGML_OP_ARGMAX; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_ARGMAX; result->src[0] = a; return result; @@ -5319,20 +5185,13 @@ struct ggml_tensor * ggml_argmax( struct ggml_tensor * ggml_repeat( struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b) { + struct ggml_tensor * a, + struct ggml_tensor * b) { GGML_ASSERT(ggml_can_repeat(a, b)); - bool is_node = false; - - if (a->grad) { - is_node = true; - } - struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, GGML_MAX_DIMS, b->ne); - result->op = GGML_OP_REPEAT; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_REPEAT; result->src[0] = a; return result; @@ -5342,24 +5201,13 @@ struct ggml_tensor * ggml_repeat( struct ggml_tensor * ggml_repeat_back( struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b) { + struct ggml_tensor * a, + struct ggml_tensor * b) { GGML_ASSERT(ggml_can_repeat(b, a)); - bool is_node = false; - - if (a->grad) { - is_node = true; - } - - if (ggml_are_same_shape(a, b) && !is_node) { - return a; - } - struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, GGML_MAX_DIMS, b->ne); - result->op = GGML_OP_REPEAT_BACK; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_REPEAT_BACK; result->src[0] = a; return result; @@ -5369,9 +5217,9 @@ struct ggml_tensor * ggml_repeat_back( struct ggml_tensor * ggml_concat( struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - int dim) { + struct ggml_tensor * a, + struct ggml_tensor * b, + int dim) { GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS); int64_t ne[GGML_MAX_DIMS]; @@ -5384,19 +5232,11 @@ struct ggml_tensor * ggml_concat( ne[d] = a->ne[d]; } - bool is_node = false; - - if (a->grad || b->grad) { - GGML_ABORT("fatal error"); // TODO: implement - is_node = true; - } - struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, GGML_MAX_DIMS, ne); ggml_set_op_params_i32(result, 0, dim); - result->op = GGML_OP_CONCAT; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_CONCAT; result->src[0] = a; result->src[1] = b; @@ -5505,20 +5345,14 @@ struct ggml_tensor * ggml_relu_inplace( struct ggml_tensor * ggml_leaky_relu( struct ggml_context * ctx, - struct ggml_tensor * a, float negative_slope, bool inplace) { - bool is_node = false; - - if (!inplace && (a->grad)) { - GGML_ABORT("fatal error"); // TODO: not implemented - is_node = true; - } - + struct ggml_tensor * a, + float negative_slope, + bool inplace) { struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); ggml_set_op_params(result, &negative_slope, sizeof(negative_slope)); - result->op = GGML_OP_LEAKY_RELU; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_LEAKY_RELU; result->src[0] = a; return result; @@ -5586,17 +5420,9 @@ struct ggml_tensor * ggml_silu_back( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b) { - bool is_node = false; - - if (a->grad || b->grad) { - // TODO: implement backward - is_node = true; - } - struct ggml_tensor * result = ggml_dup_tensor(ctx, a); - result->op = GGML_OP_SILU_BACK; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_SILU_BACK; result->src[0] = a; result->src[1] = b; @@ -5604,6 +5430,7 @@ struct ggml_tensor * ggml_silu_back( } // ggml hardswish + struct ggml_tensor * ggml_hardswish( struct ggml_context * ctx, struct ggml_tensor * a) { @@ -5611,6 +5438,7 @@ struct ggml_tensor * ggml_hardswish( } // ggml hardsigmoid + struct ggml_tensor * ggml_hardsigmoid( struct ggml_context * ctx, struct ggml_tensor * a) { @@ -5618,6 +5446,7 @@ struct ggml_tensor * ggml_hardsigmoid( } // ggml exp + struct ggml_tensor * ggml_exp( struct ggml_context * ctx, struct ggml_tensor * a) { @@ -5635,21 +5464,13 @@ struct ggml_tensor * ggml_exp_inplace( static struct ggml_tensor * ggml_norm_impl( struct ggml_context * ctx, struct ggml_tensor * a, - float eps, - bool inplace) { - bool is_node = false; - - if (!inplace && (a->grad)) { - GGML_ABORT("fatal error"); // TODO: implement backward - is_node = true; - } - + float eps, + bool inplace) { struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); ggml_set_op_params(result, &eps, sizeof(eps)); - result->op = GGML_OP_NORM; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_NORM; result->src[0] = a; return result; @@ -5658,14 +5479,14 @@ static struct ggml_tensor * ggml_norm_impl( struct ggml_tensor * ggml_norm( struct ggml_context * ctx, struct ggml_tensor * a, - float eps) { + float eps) { return ggml_norm_impl(ctx, a, eps, false); } struct ggml_tensor * ggml_norm_inplace( struct ggml_context * ctx, struct ggml_tensor * a, - float eps) { + float eps) { return ggml_norm_impl(ctx, a, eps, true); } @@ -5674,20 +5495,13 @@ struct ggml_tensor * ggml_norm_inplace( static struct ggml_tensor * ggml_rms_norm_impl( struct ggml_context * ctx, struct ggml_tensor * a, - float eps, - bool inplace) { - bool is_node = false; - - if (!inplace && (a->grad)) { - is_node = true; - } - + float eps, + bool inplace) { struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); ggml_set_op_params(result, &eps, sizeof(eps)); - result->op = GGML_OP_RMS_NORM; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_RMS_NORM; result->src[0] = a; return result; @@ -5696,14 +5510,14 @@ static struct ggml_tensor * ggml_rms_norm_impl( struct ggml_tensor * ggml_rms_norm( struct ggml_context * ctx, struct ggml_tensor * a, - float eps) { + float eps) { return ggml_rms_norm_impl(ctx, a, eps, false); } struct ggml_tensor * ggml_rms_norm_inplace( struct ggml_context * ctx, struct ggml_tensor * a, - float eps) { + float eps) { return ggml_rms_norm_impl(ctx, a, eps, true); } @@ -5713,20 +5527,12 @@ struct ggml_tensor * ggml_rms_norm_back( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, - float eps) { - bool is_node = false; - - if (a->grad) { - // TODO: implement backward - is_node = true; - } - + float eps) { struct ggml_tensor * result = ggml_dup_tensor(ctx, a); ggml_set_op_params(result, &eps, sizeof(eps)); - result->op = GGML_OP_RMS_NORM_BACK; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_RMS_NORM_BACK; result->src[0] = a; result->src[1] = b; @@ -5736,43 +5542,35 @@ struct ggml_tensor * ggml_rms_norm_back( // ggml_group_norm static struct ggml_tensor * ggml_group_norm_impl( - struct ggml_context * ctx, - struct ggml_tensor * a, - int n_groups, - float eps, - bool inplace) { - - bool is_node = false; - if (!inplace && (a->grad)) { - GGML_ABORT("fatal error"); // TODO: implement backward - is_node = true; - } - + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_groups, + float eps, + bool inplace) { struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); ggml_set_op_params_i32(result, 0, n_groups); ggml_set_op_params_f32(result, 1, eps); - result->op = GGML_OP_GROUP_NORM; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_GROUP_NORM; result->src[0] = a; return result; } struct ggml_tensor * ggml_group_norm( - struct ggml_context * ctx, - struct ggml_tensor * a, - int n_groups, - float eps) { + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_groups, + float eps) { return ggml_group_norm_impl(ctx, a, n_groups, eps, false); } struct ggml_tensor * ggml_group_norm_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a, - int n_groups, - float eps) { + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_groups, + float eps) { return ggml_group_norm_impl(ctx, a, n_groups, eps, true); } @@ -5785,17 +5583,10 @@ struct ggml_tensor * ggml_mul_mat( GGML_ASSERT(ggml_can_mul_mat(a, b)); GGML_ASSERT(!ggml_is_transposed(a)); - bool is_node = false; - - if (a->grad || b->grad) { - is_node = true; - } - const int64_t ne[4] = { a->ne[1], b->ne[1], b->ne[2], b->ne[3] }; struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); - result->op = GGML_OP_MUL_MAT; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_MUL_MAT; result->src[0] = a; result->src[1] = b; @@ -5841,17 +5632,10 @@ struct ggml_tensor * ggml_mul_mat_id( GGML_ASSERT(as->ne[0] == b->ne[0]); // can_mul_mat GGML_ASSERT(ids->ne[0] % b->ne[1] == 0); // can broadcast - bool is_node = false; - - if (as->grad || b->grad) { - is_node = true; - } - const int64_t ne[4] = { as->ne[1], ids->ne[0], b->ne[2], 1 }; struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); - result->op = GGML_OP_MUL_MAT_ID; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_MUL_MAT_ID; result->src[0] = as; result->src[1] = b; result->src[2] = ids; @@ -5868,18 +5652,11 @@ struct ggml_tensor * ggml_out_prod( GGML_ASSERT(ggml_can_out_prod(a, b)); GGML_ASSERT(!ggml_is_transposed(a)); - bool is_node = false; - - if (a->grad || b->grad) { - is_node = true; - } - // a is broadcastable to b for ne[2] and ne[3] -> use b->ne[2] and b->ne[3] const int64_t ne[4] = { a->ne[0], b->ne[0], b->ne[2], b->ne[3] }; struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); - result->op = GGML_OP_OUT_PROD; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_OUT_PROD; result->src[0] = a; result->src[1] = b; @@ -5892,21 +5669,14 @@ static struct ggml_tensor * ggml_scale_impl( struct ggml_context * ctx, struct ggml_tensor * a, float s, - bool inplace) { + bool inplace) { GGML_ASSERT(ggml_is_padded_1d(a)); - bool is_node = false; - - if (a->grad) { - is_node = true; - } - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); ggml_set_op_params(result, &s, sizeof(s)); - result->op = GGML_OP_SCALE; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_SCALE; result->src[0] = a; return result; @@ -5914,15 +5684,15 @@ static struct ggml_tensor * ggml_scale_impl( struct ggml_tensor * ggml_scale( struct ggml_context * ctx, - struct ggml_tensor * a, - float s) { + struct ggml_tensor * a, + float s) { return ggml_scale_impl(ctx, a, s, false); } struct ggml_tensor * ggml_scale_inplace( struct ggml_context * ctx, - struct ggml_tensor * a, - float s) { + struct ggml_tensor * a, + float s) { return ggml_scale_impl(ctx, a, s, true); } @@ -5936,15 +5706,9 @@ static struct ggml_tensor * ggml_set_impl( size_t nb2, size_t nb3, size_t offset, - bool inplace) { + bool inplace) { GGML_ASSERT(ggml_nelements(a) >= ggml_nelements(b)); - bool is_node = false; - - if (a->grad || b->grad) { - is_node = true; - } - // make a view of the destination struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); @@ -5952,8 +5716,7 @@ static struct ggml_tensor * ggml_set_impl( int32_t params[] = { nb1, nb2, nb3, offset, inplace ? 1 : 0 }; ggml_set_op_params(result, params, sizeof(params)); - result->op = GGML_OP_SET; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_SET; result->src[0] = a; result->src[1] = b; @@ -5962,8 +5725,8 @@ static struct ggml_tensor * ggml_set_impl( struct ggml_tensor * ggml_set( struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, + struct ggml_tensor * a, + struct ggml_tensor * b, size_t nb1, size_t nb2, size_t nb3, @@ -5973,8 +5736,8 @@ struct ggml_tensor * ggml_set( struct ggml_tensor * ggml_set_inplace( struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, + struct ggml_tensor * a, + struct ggml_tensor * b, size_t nb1, size_t nb2, size_t nb3, @@ -5984,24 +5747,24 @@ struct ggml_tensor * ggml_set_inplace( struct ggml_tensor * ggml_set_1d( struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, + struct ggml_tensor * a, + struct ggml_tensor * b, size_t offset) { return ggml_set_impl(ctx, a, b, a->nb[1], a->nb[2], a->nb[3], offset, false); } struct ggml_tensor * ggml_set_1d_inplace( struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, + struct ggml_tensor * a, + struct ggml_tensor * b, size_t offset) { return ggml_set_impl(ctx, a, b, a->nb[1], a->nb[2], a->nb[3], offset, true); } struct ggml_tensor * ggml_set_2d( struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, + struct ggml_tensor * a, + struct ggml_tensor * b, size_t nb1, size_t offset) { return ggml_set_impl(ctx, a, b, nb1, a->nb[2], a->nb[3], offset, false); @@ -6009,8 +5772,8 @@ struct ggml_tensor * ggml_set_2d( struct ggml_tensor * ggml_set_2d_inplace( struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, + struct ggml_tensor * a, + struct ggml_tensor * b, size_t nb1, size_t offset) { return ggml_set_impl(ctx, a, b, nb1, a->nb[2], a->nb[3], offset, true); @@ -6024,13 +5787,6 @@ static struct ggml_tensor * ggml_cpy_impl( struct ggml_tensor * b) { GGML_ASSERT(ggml_nelements(a) == ggml_nelements(b)); - bool is_node = false; - - if (a->grad || b->grad) { - // inplace is false and either one have a grad - is_node = true; - } - // make a view of the destination struct ggml_tensor * result = ggml_view_tensor(ctx, b); if (strlen(b->name) > 0) { @@ -6039,8 +5795,7 @@ static struct ggml_tensor * ggml_cpy_impl( ggml_format_name(result, "%s (copy)", a->name); } - result->op = GGML_OP_CPY; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_CPY; result->src[0] = a; result->src[1] = b; @@ -6058,15 +5813,11 @@ struct ggml_tensor * ggml_cast( struct ggml_context * ctx, struct ggml_tensor * a, enum ggml_type type) { - bool is_node = false; - struct ggml_tensor * result = ggml_new_tensor(ctx, type, GGML_MAX_DIMS, a->ne); ggml_format_name(result, "%s (copy)", a->name); - result->op = GGML_OP_CPY; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_CPY; result->src[0] = a; - result->src[1] = result; return result; } @@ -6076,17 +5827,10 @@ struct ggml_tensor * ggml_cast( static struct ggml_tensor * ggml_cont_impl( struct ggml_context * ctx, struct ggml_tensor * a) { - bool is_node = false; - - if (a->grad) { - is_node = true; - } - struct ggml_tensor * result = ggml_dup_tensor(ctx, a); ggml_format_name(result, "%s (cont)", a->name); - result->op = GGML_OP_CONT; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_CONT; result->src[0] = a; return result; @@ -6132,13 +5876,10 @@ struct ggml_tensor * ggml_cont_4d( int64_t ne3) { GGML_ASSERT(ggml_nelements(a) == (ne0*ne1*ne2*ne3)); - bool is_node = false; - struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3); ggml_format_name(result, "%s (cont)", a->name); - result->op = GGML_OP_CONT; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_CONT; result->src[0] = a; return result; @@ -6154,22 +5895,10 @@ struct ggml_tensor * ggml_reshape( // as only the shape of b is relevant, and not its memory layout, b is allowed to be non contiguous. GGML_ASSERT(ggml_nelements(a) == ggml_nelements(b)); - bool is_node = false; - - if (a->grad) { - is_node = true; - } - - if (b->grad) { - // gradient propagation is not supported - //GGML_ABORT("fatal error"); - } - struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, GGML_MAX_DIMS, b->ne, a, 0); ggml_format_name(result, "%s (reshaped)", a->name); - result->op = GGML_OP_RESHAPE; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_RESHAPE; result->src[0] = a; return result; @@ -6182,18 +5911,11 @@ struct ggml_tensor * ggml_reshape_1d( GGML_ASSERT(ggml_is_contiguous(a)); GGML_ASSERT(ggml_nelements(a) == ne0); - bool is_node = false; - - if (a->grad) { - is_node = true; - } - const int64_t ne[1] = { ne0 }; struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 1, ne, a, 0); ggml_format_name(result, "%s (reshaped)", a->name); - result->op = GGML_OP_RESHAPE; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_RESHAPE; result->src[0] = a; return result; @@ -6207,18 +5929,11 @@ struct ggml_tensor * ggml_reshape_2d( GGML_ASSERT(ggml_is_contiguous(a)); GGML_ASSERT(ggml_nelements(a) == ne0*ne1); - bool is_node = false; - - if (a->grad) { - is_node = true; - } - const int64_t ne[2] = { ne0, ne1 }; struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 2, ne, a, 0); ggml_format_name(result, "%s (reshaped)", a->name); - result->op = GGML_OP_RESHAPE; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_RESHAPE; result->src[0] = a; return result; @@ -6233,18 +5948,11 @@ struct ggml_tensor * ggml_reshape_3d( GGML_ASSERT(ggml_is_contiguous(a)); GGML_ASSERT(ggml_nelements(a) == ne0*ne1*ne2); - bool is_node = false; - - if (a->grad) { - is_node = true; - } - const int64_t ne[3] = { ne0, ne1, ne2 }; struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 3, ne, a, 0); ggml_format_name(result, "%s (reshaped)", a->name); - result->op = GGML_OP_RESHAPE; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_RESHAPE; result->src[0] = a; return result; @@ -6260,18 +5968,11 @@ struct ggml_tensor * ggml_reshape_4d( GGML_ASSERT(ggml_is_contiguous(a)); GGML_ASSERT(ggml_nelements(a) == ne0*ne1*ne2*ne3); - bool is_node = false; - - if (a->grad) { - is_node = true; - } - const int64_t ne[4] = { ne0, ne1, ne2, ne3 }; struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 4, ne, a, 0); ggml_format_name(result, "%s (reshaped)", a->name); - result->op = GGML_OP_RESHAPE; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_RESHAPE; result->src[0] = a; return result; @@ -6283,20 +5984,12 @@ static struct ggml_tensor * ggml_view_impl( int n_dims, const int64_t * ne, size_t offset) { - - bool is_node = false; - - if (a->grad) { - is_node = true; - } - struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, n_dims, ne, a, offset); ggml_format_name(result, "%s (view)", a->name); ggml_set_op_params(result, &offset, sizeof(offset)); - result->op = GGML_OP_VIEW; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_VIEW; result->src[0] = a; return result; @@ -6309,7 +6002,6 @@ struct ggml_tensor * ggml_view_1d( struct ggml_tensor * a, int64_t ne0, size_t offset) { - struct ggml_tensor * result = ggml_view_impl(ctx, a, 1, &ne0, offset); return result; @@ -6324,7 +6016,6 @@ struct ggml_tensor * ggml_view_2d( int64_t ne1, size_t nb1, size_t offset) { - const int64_t ne[2] = { ne0, ne1 }; struct ggml_tensor * result = ggml_view_impl(ctx, a, 2, ne, offset); @@ -6347,7 +6038,6 @@ struct ggml_tensor * ggml_view_3d( size_t nb1, size_t nb2, size_t offset) { - const int64_t ne[3] = { ne0, ne1, ne2 }; struct ggml_tensor * result = ggml_view_impl(ctx, a, 3, ne, offset); @@ -6372,7 +6062,6 @@ struct ggml_tensor * ggml_view_4d( size_t nb2, size_t nb3, size_t offset) { - const int64_t ne[4] = { ne0, ne1, ne2, ne3 }; struct ggml_tensor * result = ggml_view_impl(ctx, a, 4, ne, offset); @@ -6405,12 +6094,6 @@ struct ggml_tensor * ggml_permute( GGML_ASSERT(axis1 != axis3); GGML_ASSERT(axis2 != axis3); - bool is_node = false; - - if (a->grad) { - is_node = true; - } - struct ggml_tensor * result = ggml_view_tensor(ctx, a); ggml_format_name(result, "%s (permuted)", a->name); @@ -6437,8 +6120,7 @@ struct ggml_tensor * ggml_permute( result->nb[2] = nb[2]; result->nb[3] = nb[3]; - result->op = GGML_OP_PERMUTE; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_PERMUTE; result->src[0] = a; int32_t params[] = { axis0, axis1, axis2, axis3 }; @@ -6452,12 +6134,6 @@ struct ggml_tensor * ggml_permute( struct ggml_tensor * ggml_transpose( struct ggml_context * ctx, struct ggml_tensor * a) { - bool is_node = false; - - if (a->grad) { - is_node = true; - } - struct ggml_tensor * result = ggml_view_tensor(ctx, a); ggml_format_name(result, "%s (transposed)", a->name); @@ -6467,8 +6143,7 @@ struct ggml_tensor * ggml_transpose( result->nb[0] = a->nb[1]; result->nb[1] = a->nb[0]; - result->op = GGML_OP_TRANSPOSE; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_TRANSPOSE; result->src[0] = a; return result; @@ -6484,12 +6159,6 @@ struct ggml_tensor * ggml_get_rows( GGML_ASSERT(b->ne[3] == 1); GGML_ASSERT(b->type == GGML_TYPE_I32); - bool is_node = false; - - if (a->grad || b->grad) { - is_node = true; - } - // TODO: implement non F32 return enum ggml_type type = GGML_TYPE_F32; if (a->type == GGML_TYPE_I32) { @@ -6497,8 +6166,7 @@ struct ggml_tensor * ggml_get_rows( } struct ggml_tensor * result = ggml_new_tensor_4d(ctx, type, a->ne[0], b->ne[0], b->ne[1], b->ne[2]); - result->op = GGML_OP_GET_ROWS; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_GET_ROWS; result->src[0] = a; result->src[1] = b; @@ -6515,18 +6183,11 @@ struct ggml_tensor * ggml_get_rows_back( GGML_ASSERT(ggml_is_matrix(a) && ggml_is_vector(b) && b->type == GGML_TYPE_I32); GGML_ASSERT(ggml_is_matrix(c) && (a->ne[0] == c->ne[0])); - bool is_node = false; - - if (a->grad || b->grad) { - is_node = true; - } - // TODO: implement non F32 return //struct ggml_tensor * result = ggml_new_tensor_2d(ctx, a->type, a->ne[0], b->ne[0]); struct ggml_tensor * result = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, c->ne[0], c->ne[1]); - result->op = GGML_OP_GET_ROWS_BACK; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_GET_ROWS_BACK; result->src[0] = a; result->src[1] = b; @@ -6539,17 +6200,11 @@ struct ggml_tensor * ggml_diag( struct ggml_context * ctx, struct ggml_tensor * a) { GGML_ASSERT(a->ne[1] == 1); - bool is_node = false; - - if (a->grad) { - is_node = true; - } const int64_t ne[4] = { a->ne[0], a->ne[0], a->ne[2], a->ne[3] }; struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, 4, ne); - result->op = GGML_OP_DIAG; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_DIAG; result->src[0] = a; return result; @@ -6562,19 +6217,12 @@ static struct ggml_tensor * ggml_diag_mask_inf_impl( struct ggml_tensor * a, int n_past, bool inplace) { - bool is_node = false; - - if (a->grad) { - is_node = true; - } - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); int32_t params[] = { n_past }; ggml_set_op_params(result, params, sizeof(params)); - result->op = GGML_OP_DIAG_MASK_INF; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_DIAG_MASK_INF; result->src[0] = a; return result; @@ -6601,19 +6249,12 @@ static struct ggml_tensor * ggml_diag_mask_zero_impl( struct ggml_tensor * a, int n_past, bool inplace) { - bool is_node = false; - - if (a->grad) { - is_node = true; - } - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); int32_t params[] = { n_past }; ggml_set_op_params(result, params, sizeof(params)); - result->op = GGML_OP_DIAG_MASK_ZERO; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_DIAG_MASK_ZERO; result->src[0] = a; return result; @@ -6656,19 +6297,12 @@ static struct ggml_tensor * ggml_soft_max_impl( GGML_ASSERT(mask); } - bool is_node = false; - - if (a->grad) { - is_node = true; - } - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); float params[] = { scale, max_bias }; ggml_set_op_params(result, params, sizeof(params)); - result->op = GGML_OP_SOFT_MAX; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_SOFT_MAX; result->src[0] = a; result->src[1] = mask; @@ -6703,16 +6337,9 @@ static struct ggml_tensor * ggml_soft_max_back_impl( struct ggml_tensor * a, struct ggml_tensor * b, bool inplace) { - bool is_node = false; - - if (a->grad || b->grad) { - is_node = true; // TODO : implement backward pass - } - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - result->op = GGML_OP_SOFT_MAX_BACK; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_SOFT_MAX_BACK; result->src[0] = a; result->src[1] = b; @@ -6761,12 +6388,6 @@ static struct ggml_tensor * ggml_rope_impl( GGML_ASSERT(c->ne[0] >= n_dims / 2); } - bool is_node = false; - - if (a->grad) { - is_node = true; - } - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); int32_t params[11] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig }; @@ -6778,8 +6399,7 @@ static struct ggml_tensor * ggml_rope_impl( memcpy(params + 10, &beta_slow, sizeof(float)); ggml_set_op_params(result, params, sizeof(params)); - result->op = GGML_OP_ROPE; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_ROPE; result->src[0] = a; result->src[1] = b; result->src[2] = c; @@ -6907,13 +6527,6 @@ struct ggml_tensor * ggml_rope_back( GGML_ASSERT(b->type == GGML_TYPE_I32); GGML_ASSERT(a->ne[2] == b->ne[0]); - bool is_node = false; - - if (a->grad) { - GGML_ASSERT(false && "backwards pass not implemented"); - is_node = false; - } - struct ggml_tensor * result = ggml_dup_tensor(ctx, a); int32_t params[11] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig }; @@ -6925,8 +6538,7 @@ struct ggml_tensor * ggml_rope_back( memcpy(params + 10, &beta_slow, sizeof(float)); ggml_set_op_params(result, params, sizeof(params)); - result->op = GGML_OP_ROPE_BACK; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_ROPE_BACK; result->src[0] = a; result->src[1] = b; result->src[2] = c; @@ -6941,21 +6553,13 @@ struct ggml_tensor * ggml_clamp( struct ggml_tensor * a, float min, float max) { - bool is_node = false; - - if (a->grad) { - GGML_ABORT("fatal error"); // TODO: implement backward - is_node = true; - } - // TODO: when implement backward, fix this: struct ggml_tensor * result = ggml_view_tensor(ctx, a); float params[] = { min, max }; ggml_set_op_params(result, params, sizeof(params)); - result->op = GGML_OP_CLAMP; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_CLAMP; result->src[0] = a; return result; @@ -7017,13 +6621,6 @@ GGML_API struct ggml_tensor * ggml_conv_transpose_1d( GGML_ASSERT(p0 == 0); GGML_ASSERT(d0 == 1); - bool is_node = false; - - if (a->grad || b->grad) { - GGML_ABORT("fatal error"); // TODO: implement backward - is_node = true; - } - const int64_t ne[4] = { ggml_calc_conv_transpose_1d_output_size(b->ne[0], a->ne[0], s0, 0 /*p0*/, 1 /*d0*/), a->ne[1], b->ne[2], 1, @@ -7033,8 +6630,7 @@ GGML_API struct ggml_tensor * ggml_conv_transpose_1d( int32_t params[] = { s0, p0, d0 }; ggml_set_op_params(result, params, sizeof(params)); - result->op = GGML_OP_CONV_TRANSPOSE_1D; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_CONV_TRANSPOSE_1D; result->src[0] = a; result->src[1] = b; @@ -7042,17 +6638,17 @@ GGML_API struct ggml_tensor * ggml_conv_transpose_1d( } // ggml_conv_depthwise -struct ggml_tensor * ggml_conv_depthwise_2d( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - int s0, - int s1, - int p0, - int p1, - int d0, - int d1) { +struct ggml_tensor * ggml_conv_depthwise_2d( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int s0, + int s1, + int p0, + int p1, + int d0, + int d1) { struct ggml_tensor * new_a = ggml_reshape_4d(ctx, a, a->ne[0], a->ne[1], 1, a->ne[2] * a->ne[3]); struct ggml_tensor * im2col = ggml_im2col(ctx, new_a, ggml_reshape_4d(ctx, b, b->ne[0], b->ne[1], 1, b->ne[2] * b->ne[3]), @@ -7072,29 +6668,23 @@ struct ggml_tensor * ggml_conv_depthwise_2d( // b: [N, IC, IH, IW] // result: [N, OH, OW, IC*KH*KW] struct ggml_tensor * ggml_im2col( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - int s0, - int s1, - int p0, - int p1, - int d0, - int d1, - bool is_2D, - enum ggml_type dst_type) { - + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int s0, + int s1, + int p0, + int p1, + int d0, + int d1, + bool is_2D, + enum ggml_type dst_type) { if(is_2D) { GGML_ASSERT(a->ne[2] == b->ne[2]); } else { GGML_ASSERT(a->ne[1] == b->ne[1]); GGML_ASSERT(b->ne[3] == 1); } - bool is_node = false; - - if (/*a->grad ||*/ b->grad) { // a is only used for its shape, not its data - is_node = true; - } const int64_t OH = is_2D ? ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1) : 0; const int64_t OW = ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0); @@ -7113,8 +6703,7 @@ struct ggml_tensor * ggml_im2col( int32_t params[] = { s0, s1, p0, p1, d0, d1, (is_2D ? 1 : 0) }; ggml_set_op_params(result, params, sizeof(params)); - result->op = GGML_OP_IM2COL; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_IM2COL; result->src[0] = a; result->src[1] = b; @@ -7122,30 +6711,22 @@ struct ggml_tensor * ggml_im2col( } struct ggml_tensor * ggml_im2col_back( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - int64_t * ne, - int s0, - int s1, - int p0, - int p1, - int d0, - int d1, - bool is_2D) { - - bool is_node = false; - - if (/*a->grad ||*/ b->grad) { // a is only used for its shape, not its data - is_node = true; - } - + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int64_t * ne, + int s0, + int s1, + int p0, + int p1, + int d0, + int d1, + bool is_2D) { struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); int32_t params[] = { s0, s1, p0, p1, d0, d1, (is_2D ? 1 : 0) }; ggml_set_op_params(result, params, sizeof(params)); - result->op = GGML_OP_IM2COL_BACK; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_IM2COL_BACK; result->src[0] = a; result->src[1] = b; @@ -7159,12 +6740,12 @@ struct ggml_tensor * ggml_conv_2d( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, - int s0, - int s1, - int p0, - int p1, - int d0, - int d1) { + int s0, + int s1, + int p0, + int p1, + int d0, + int d1) { struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, s1, p0, p1, d0, d1, true, a->type); // [N, OH, OW, IC * KH * KW] struct ggml_tensor * result = @@ -7180,6 +6761,7 @@ struct ggml_tensor * ggml_conv_2d( } // ggml_conv_2d_sk_p0 + struct ggml_tensor * ggml_conv_2d_sk_p0( struct ggml_context * ctx, struct ggml_tensor * a, @@ -7209,13 +6791,6 @@ struct ggml_tensor * ggml_conv_transpose_2d_p0( int stride) { GGML_ASSERT(a->ne[3] == b->ne[2]); - bool is_node = false; - - if (a->grad || b->grad) { - GGML_ABORT("fatal error"); // TODO: implement backward - is_node = true; - } - const int64_t ne[4] = { ggml_calc_conv_transpose_output_size(b->ne[0], a->ne[0], stride, 0 /*p0*/), ggml_calc_conv_transpose_output_size(b->ne[1], a->ne[1], stride, 0 /*p1*/), @@ -7226,8 +6801,7 @@ struct ggml_tensor * ggml_conv_transpose_2d_p0( ggml_set_op_params_i32(result, 0, stride); - result->op = GGML_OP_CONV_TRANSPOSE_2D; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_CONV_TRANSPOSE_2D; result->src[0] = a; result->src[1] = b; @@ -7249,14 +6823,6 @@ struct ggml_tensor * ggml_pool_1d( int k0, int s0, int p0) { - - bool is_node = false; - - if (a->grad) { - GGML_ABORT("fatal error"); // TODO: implement backward - is_node = true; - } - const int64_t ne[4] = { ggml_calc_pool_output_size(a->ne[0], k0, s0, p0), a->ne[1], @@ -7268,8 +6834,7 @@ struct ggml_tensor * ggml_pool_1d( int32_t params[] = { op, k0, s0, p0 }; ggml_set_op_params(result, params, sizeof(params)); - result->op = GGML_OP_POOL_1D; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_POOL_1D; result->src[0] = a; return result; @@ -7287,13 +6852,6 @@ struct ggml_tensor * ggml_pool_2d( int s1, float p0, float p1) { - - bool is_node = false; - - if (a->grad) { - is_node = true; - } - struct ggml_tensor * result; const int64_t ne[4] = { ggml_calc_pool_output_size(a->ne[0], k0, s0, p0), @@ -7306,9 +6864,9 @@ struct ggml_tensor * ggml_pool_2d( int32_t params[] = { op, k0, k1, s0, s1, p0, p1 }; ggml_set_op_params(result, params, sizeof(params)); - result->op = GGML_OP_POOL_2D; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_POOL_2D; result->src[0] = a; + return result; } @@ -7323,100 +6881,74 @@ struct ggml_tensor * ggml_pool_2d_back( int s1, float p0, float p1) { - - bool is_node = false; - - if (a->grad) { - is_node = true; - } - struct ggml_tensor * result; result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, af->ne); int32_t params[] = { op, k0, k1, s0, s1, p0, p1 }; ggml_set_op_params(result, params, sizeof(params)); - result->op = GGML_OP_POOL_2D_BACK; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_POOL_2D_BACK; result->src[0] = a; result->src[1] = af; + return result; } // ggml_upscale static struct ggml_tensor * ggml_upscale_impl( - struct ggml_context * ctx, - struct ggml_tensor * a, - int ne0, - int ne1, - int ne2, - int ne3) { - bool is_node = false; - - if (a->grad) { - GGML_ABORT("fatal error"); // TODO: implement backward - is_node = true; - } - + struct ggml_context * ctx, + struct ggml_tensor * a, + int ne0, + int ne1, + int ne2, + int ne3) { GGML_ASSERT(a->ne[0] <= ne0); GGML_ASSERT(a->ne[1] <= ne1); GGML_ASSERT(a->ne[2] <= ne2); GGML_ASSERT(a->ne[3] <= ne3); - struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, - ne0, - ne1, - ne2, - ne3 - ); + struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3); - result->op = GGML_OP_UPSCALE; - - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_UPSCALE; result->src[0] = a; return result; } struct ggml_tensor * ggml_upscale( - struct ggml_context * ctx, - struct ggml_tensor * a, - int scale_factor) { + struct ggml_context * ctx, + struct ggml_tensor * a, + int scale_factor) { return ggml_upscale_impl(ctx, a, a->ne[0] * scale_factor, a->ne[1] * scale_factor, a->ne[2], a->ne[3]); } struct ggml_tensor * ggml_upscale_ext( - struct ggml_context * ctx, - struct ggml_tensor * a, - int ne0, - int ne1, - int ne2, - int ne3) { + struct ggml_context * ctx, + struct ggml_tensor * a, + int ne0, + int ne1, + int ne2, + int ne3) { return ggml_upscale_impl(ctx, a, ne0, ne1, ne2, ne3); } // ggml_pad struct ggml_tensor * ggml_pad( - struct ggml_context * ctx, - struct ggml_tensor * a, - int p0, int p1, int p2, int p3) { - bool is_node = false; - - if (a->grad) { - GGML_ABORT("fatal error"); // TODO: implement backward - is_node = true; - } - + struct ggml_context * ctx, + struct ggml_tensor * a, + int p0, + int p1, + int p2, + int p3) { struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, a->ne[0] + p0, a->ne[1] + p1, a->ne[2] + p2, a->ne[3] + p3); - result->op = GGML_OP_PAD; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_PAD; result->src[0] = a; return result; @@ -7425,39 +6957,32 @@ struct ggml_tensor * ggml_pad( // ggml_arange struct ggml_tensor * ggml_arange( - struct ggml_context * ctx, - float start, - float stop, - float step) { - + struct ggml_context * ctx, + float start, + float stop, + float step) { GGML_ASSERT(stop > start); const int64_t steps = (int64_t) ceilf((stop - start) / step); struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, steps); - result->op = GGML_OP_ARANGE; ggml_set_op_params_f32(result, 0, start); ggml_set_op_params_f32(result, 1, stop); ggml_set_op_params_f32(result, 2, step); + result->op = GGML_OP_ARANGE; + return result; } // ggml_timestep_embedding struct ggml_tensor * ggml_timestep_embedding( - struct ggml_context * ctx, - struct ggml_tensor * timesteps, - int dim, - int max_period) { - bool is_node = false; - - if (timesteps->grad) { - GGML_ABORT("fatal error"); // TODO: implement backward - is_node = true; - } - + struct ggml_context * ctx, + struct ggml_tensor * timesteps, + int dim, + int max_period) { int actual_dim = dim; if (dim % 2 != 0) { actual_dim = dim + 1; @@ -7465,11 +6990,10 @@ struct ggml_tensor * ggml_timestep_embedding( struct ggml_tensor * result = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, actual_dim, timesteps->ne[0]); - result->op = GGML_OP_TIMESTEP_EMBEDDING; ggml_set_op_params_i32(result, 0, dim); ggml_set_op_params_i32(result, 1, max_period); - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_TIMESTEP_EMBEDDING; result->src[0] = timesteps; return result; @@ -7478,22 +7002,14 @@ struct ggml_tensor * ggml_timestep_embedding( // ggml_argsort struct ggml_tensor * ggml_argsort( - struct ggml_context * ctx, - struct ggml_tensor * a, - enum ggml_sort_order order) { - bool is_node = false; - - if (a->grad) { - GGML_ABORT("fatal error"); // TODO: not implemented - is_node = true; - } - + struct ggml_context * ctx, + struct ggml_tensor * a, + enum ggml_sort_order order) { struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_I32, GGML_MAX_DIMS, a->ne); ggml_set_op_params_i32(result, 0, (int32_t) order); - result->op = GGML_OP_ARGSORT; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_ARGSORT; result->src[0] = a; return result; @@ -7546,10 +7062,6 @@ struct ggml_tensor * ggml_flash_attn_ext( bool is_node = false; - if (q->grad || k->grad || v->grad) { - is_node = true; - } - // permute(0, 2, 1, 3) int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] }; struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); @@ -7676,17 +7188,9 @@ struct ggml_tensor * ggml_ssm_conv( GGML_ASSERT(sx->ne[1] == d_inner); GGML_ASSERT(n_t >= 0); - bool is_node = false; - - if (sx->grad || c->grad) { - GGML_ABORT("fatal error"); // TODO: implement - is_node = true; - } - struct ggml_tensor * result = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_inner, n_t, n_s); - result->op = GGML_OP_SSM_CONV; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_SSM_CONV; result->src[0] = sx; result->src[1] = c; @@ -7730,18 +7234,10 @@ struct ggml_tensor * ggml_ssm_scan( GGML_ASSERT(B->ne[2] == n_seqs); } - bool is_node = false; - - if (s->grad || x->grad || dt->grad || A->grad || B->grad || C->grad) { - GGML_ABORT("fatal error"); // TODO: implement - is_node = true; - } - // concatenated y + ssm_states struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + ggml_nelements(s)); result->op = GGML_OP_SSM_SCAN; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = s; result->src[1] = x; result->src[2] = dt; @@ -7761,13 +7257,6 @@ struct ggml_tensor * ggml_win_part( GGML_ASSERT(a->ne[3] == 1); GGML_ASSERT(a->type == GGML_TYPE_F32); - bool is_node = false; - - if (a->grad) { - GGML_ABORT("fatal error"); // TODO: implement backward - is_node = true; - } - // padding const int px = (w - a->ne[1]%w)%w; const int py = (w - a->ne[2]%w)%w; @@ -7782,8 +7271,7 @@ struct ggml_tensor * ggml_win_part( int32_t params[] = { npx, npy, w }; ggml_set_op_params(result, params, sizeof(params)); - result->op = GGML_OP_WIN_PART; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_WIN_PART; result->src[0] = a; return result; @@ -7799,21 +7287,13 @@ struct ggml_tensor * ggml_win_unpart( int w) { GGML_ASSERT(a->type == GGML_TYPE_F32); - bool is_node = false; - - if (a->grad) { - GGML_ABORT("fatal error"); // TODO: implement backward - is_node = true; - } - const int64_t ne[4] = { a->ne[0], w0, h0, 1, }; struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 3, ne); int32_t params[] = { w }; ggml_set_op_params(result, params, sizeof(params)); - result->op = GGML_OP_WIN_UNPART; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_WIN_UNPART; result->src[0] = a; return result; @@ -7829,18 +7309,10 @@ struct ggml_tensor * ggml_get_rel_pos( GGML_ASSERT(qh == kh); GGML_ASSERT(2*MAX(qh, kh) - 1 == a->ne[1]); - bool is_node = false; - - if (a->grad) { - GGML_ABORT("fatal error"); // TODO: implement backward - is_node = true; - } - const int64_t ne[4] = { a->ne[0], kh, qh, 1, }; struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F16, 3, ne); - result->op = GGML_OP_GET_REL_POS; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_GET_REL_POS; result->src[0] = a; return result; @@ -7864,17 +7336,10 @@ static struct ggml_tensor * ggml_add_rel_pos_impl( GGML_ASSERT(pw->ne[0]*pw->ne[0] == a->ne[0]); GGML_ASSERT(pw->ne[1]*pw->ne[2] == a->ne[1]); - bool is_node = false; - - if (!inplace && (a->grad || pw->grad || ph->grad)) { - is_node = true; - } - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); ggml_set_op_params_i32(result, 0, inplace ? 1 : 0); - result->op = GGML_OP_ADD_REL_POS; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_ADD_REL_POS; result->src[0] = a; result->src[1] = pw; result->src[2] = ph; @@ -7902,12 +7367,12 @@ struct ggml_tensor * ggml_add_rel_pos_inplace( struct ggml_tensor * ggml_rwkv_wkv( struct ggml_context * ctx, - struct ggml_tensor * k, - struct ggml_tensor * v, - struct ggml_tensor * r, - struct ggml_tensor * tf, - struct ggml_tensor * td, - struct ggml_tensor * state) { + struct ggml_tensor * k, + struct ggml_tensor * v, + struct ggml_tensor * r, + struct ggml_tensor * tf, + struct ggml_tensor * td, + struct ggml_tensor * state) { GGML_ASSERT(ggml_is_contiguous(k)); GGML_ASSERT(ggml_is_contiguous(v)); GGML_ASSERT(ggml_is_contiguous(r)); @@ -7928,19 +7393,11 @@ struct ggml_tensor * ggml_rwkv_wkv( GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs); } - bool is_node = false; - - if (k->grad || v->grad || r->grad || tf->grad || td->grad || state->grad) { - GGML_ABORT("fatal error"); // TODO: implement backward - is_node = true; - } - // concat output and new_state const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 }; struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); - result->op = GGML_OP_RWKV_WKV; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_RWKV_WKV; result->src[0] = k; result->src[1] = v; result->src[2] = r; @@ -7955,23 +7412,16 @@ struct ggml_tensor * ggml_rwkv_wkv( static struct ggml_tensor * ggml_unary_impl( struct ggml_context * ctx, - struct ggml_tensor * a, - enum ggml_unary_op op, - bool inplace) { + struct ggml_tensor * a, + enum ggml_unary_op op, + bool inplace) { GGML_ASSERT(ggml_is_contiguous_1(a)); - bool is_node = false; - - if (!inplace && (a->grad)) { - is_node = true; - } - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); ggml_set_op_params_i32(result, 0, (int32_t) op); - result->op = GGML_OP_UNARY; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_UNARY; result->src[0] = a; return result; @@ -7980,14 +7430,14 @@ static struct ggml_tensor * ggml_unary_impl( struct ggml_tensor * ggml_unary( struct ggml_context * ctx, struct ggml_tensor * a, - enum ggml_unary_op op) { + enum ggml_unary_op op) { return ggml_unary_impl(ctx, a, op, false); } struct ggml_tensor * ggml_unary_inplace( struct ggml_context * ctx, struct ggml_tensor * a, - enum ggml_unary_op op) { + enum ggml_unary_op op) { return ggml_unary_impl(ctx, a, op, true); } @@ -7996,20 +7446,13 @@ struct ggml_tensor * ggml_unary_inplace( static struct ggml_tensor * ggml_map_unary_impl_f32( struct ggml_context * ctx, struct ggml_tensor * a, - const ggml_unary_op_f32_t fun, - bool inplace) { - bool is_node = false; - - if (!inplace && a->grad) { - is_node = true; - } - + const ggml_unary_op_f32_t fun, + bool inplace) { struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); ggml_set_op_params(result, (const void *) &fun, sizeof(fun)); - result->op = GGML_OP_MAP_UNARY; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_MAP_UNARY; result->src[0] = a; return result; @@ -8018,14 +7461,14 @@ static struct ggml_tensor * ggml_map_unary_impl_f32( struct ggml_tensor * ggml_map_unary_f32( struct ggml_context * ctx, struct ggml_tensor * a, - const ggml_unary_op_f32_t fun) { + const ggml_unary_op_f32_t fun) { return ggml_map_unary_impl_f32(ctx, a, fun, false); } struct ggml_tensor * ggml_map_unary_inplace_f32( struct ggml_context * ctx, struct ggml_tensor * a, - const ggml_unary_op_f32_t fun) { + const ggml_unary_op_f32_t fun) { return ggml_map_unary_impl_f32(ctx, a, fun, true); } @@ -8035,22 +7478,15 @@ static struct ggml_tensor * ggml_map_binary_impl_f32( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, - const ggml_binary_op_f32_t fun, - bool inplace) { + const ggml_binary_op_f32_t fun, + bool inplace) { GGML_ASSERT(ggml_are_same_shape(a, b)); - bool is_node = false; - - if (!inplace && (a->grad || b->grad)) { - is_node = true; - } - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); ggml_set_op_params(result, (const void *) &fun, sizeof(fun)); - result->op = GGML_OP_MAP_BINARY; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_MAP_BINARY; result->src[0] = a; result->src[1] = b; @@ -8061,7 +7497,7 @@ struct ggml_tensor * ggml_map_binary_f32( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, - const ggml_binary_op_f32_t fun) { + const ggml_binary_op_f32_t fun) { return ggml_map_binary_impl_f32(ctx, a, b, fun, false); } @@ -8069,7 +7505,7 @@ struct ggml_tensor * ggml_map_binary_inplace_f32( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, - const ggml_binary_op_f32_t fun) { + const ggml_binary_op_f32_t fun) { return ggml_map_binary_impl_f32(ctx, a, b, fun, true); } @@ -8079,19 +7515,12 @@ static struct ggml_tensor * ggml_map_custom1_impl_f32( struct ggml_context * ctx, struct ggml_tensor * a, const ggml_custom1_op_f32_t fun, - bool inplace) { - bool is_node = false; - - if (!inplace && a->grad) { - is_node = true; - } - + bool inplace) { struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); ggml_set_op_params(result, (const void *) &fun, sizeof(fun)); - result->op = GGML_OP_MAP_CUSTOM1_F32; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_MAP_CUSTOM1_F32; result->src[0] = a; return result; @@ -8118,19 +7547,12 @@ static struct ggml_tensor * ggml_map_custom2_impl_f32( struct ggml_tensor * a, struct ggml_tensor * b, const ggml_custom2_op_f32_t fun, - bool inplace) { - bool is_node = false; - - if (!inplace && (a->grad || b->grad)) { - is_node = true; - } - + bool inplace) { struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); ggml_set_op_params(result, (const void *) &fun, sizeof(fun)); - result->op = GGML_OP_MAP_CUSTOM2_F32; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_MAP_CUSTOM2_F32; result->src[0] = a; result->src[1] = b; @@ -8161,19 +7583,12 @@ static struct ggml_tensor * ggml_map_custom3_impl_f32( struct ggml_tensor * b, struct ggml_tensor * c, const ggml_custom3_op_f32_t fun, - bool inplace) { - bool is_node = false; - - if (!inplace && (a->grad || b->grad || c->grad)) { - is_node = true; - } - + bool inplace) { struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); ggml_set_op_params(result, (const void *) &fun, sizeof(fun)); - result->op = GGML_OP_MAP_CUSTOM3_F32; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_MAP_CUSTOM3_F32; result->src[0] = a; result->src[1] = b; result->src[2] = c; @@ -8201,26 +7616,20 @@ struct ggml_tensor * ggml_map_custom3_inplace_f32( // ggml_map_custom1 struct ggml_map_custom1_op_params { - ggml_custom1_op_t fun; - int n_tasks; - void * userdata; + ggml_custom1_op_t fun; + int n_tasks; + void * userdata; }; static struct ggml_tensor * ggml_map_custom1_impl( - struct ggml_context * ctx, - struct ggml_tensor * a, - const ggml_custom1_op_t fun, - int n_tasks, - void * userdata, - bool inplace) { + struct ggml_context * ctx, + struct ggml_tensor * a, + const ggml_custom1_op_t fun, + int n_tasks, + void * userdata, + bool inplace) { GGML_ASSERT(n_tasks == GGML_N_TASKS_MAX || n_tasks > 0); - bool is_node = false; - - if (!inplace && a->grad) { - is_node = true; - } - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); struct ggml_map_custom1_op_params params = { @@ -8230,55 +7639,48 @@ static struct ggml_tensor * ggml_map_custom1_impl( }; ggml_set_op_params(result, (const void *) ¶ms, sizeof(params)); - result->op = GGML_OP_MAP_CUSTOM1; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_MAP_CUSTOM1; result->src[0] = a; return result; } struct ggml_tensor * ggml_map_custom1( - struct ggml_context * ctx, - struct ggml_tensor * a, - const ggml_custom1_op_t fun, - int n_tasks, - void * userdata) { + struct ggml_context * ctx, + struct ggml_tensor * a, + const ggml_custom1_op_t fun, + int n_tasks, + void * userdata) { return ggml_map_custom1_impl(ctx, a, fun, n_tasks, userdata, false); } struct ggml_tensor * ggml_map_custom1_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a, - const ggml_custom1_op_t fun, - int n_tasks, - void * userdata) { + struct ggml_context * ctx, + struct ggml_tensor * a, + const ggml_custom1_op_t fun, + int n_tasks, + void * userdata) { return ggml_map_custom1_impl(ctx, a, fun, n_tasks, userdata, true); } // ggml_map_custom2 struct ggml_map_custom2_op_params { - ggml_custom2_op_t fun; - int n_tasks; - void * userdata; + ggml_custom2_op_t fun; + int n_tasks; + void * userdata; }; static struct ggml_tensor * ggml_map_custom2_impl( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - const ggml_custom2_op_t fun, - int n_tasks, - void * userdata, - bool inplace) { + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + const ggml_custom2_op_t fun, + int n_tasks, + void * userdata, + bool inplace) { GGML_ASSERT(n_tasks == GGML_N_TASKS_MAX || n_tasks > 0); - bool is_node = false; - - if (!inplace && (a->grad || b->grad)) { - is_node = true; - } - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); struct ggml_map_custom2_op_params params = { @@ -8288,8 +7690,7 @@ static struct ggml_tensor * ggml_map_custom2_impl( }; ggml_set_op_params(result, (const void *) ¶ms, sizeof(params)); - result->op = GGML_OP_MAP_CUSTOM2; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_MAP_CUSTOM2; result->src[0] = a; result->src[1] = b; @@ -8297,22 +7698,22 @@ static struct ggml_tensor * ggml_map_custom2_impl( } struct ggml_tensor * ggml_map_custom2( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - const ggml_custom2_op_t fun, - int n_tasks, - void * userdata) { + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + const ggml_custom2_op_t fun, + int n_tasks, + void * userdata) { return ggml_map_custom2_impl(ctx, a, b, fun, n_tasks, userdata, false); } struct ggml_tensor * ggml_map_custom2_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - const ggml_custom2_op_t fun, - int n_tasks, - void * userdata) { + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + const ggml_custom2_op_t fun, + int n_tasks, + void * userdata) { return ggml_map_custom2_impl(ctx, a, b, fun, n_tasks, userdata, true); } @@ -8325,22 +7726,16 @@ struct ggml_map_custom3_op_params { }; static struct ggml_tensor * ggml_map_custom3_impl( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - struct ggml_tensor * c, - const ggml_custom3_op_t fun, - int n_tasks, - void * userdata, - bool inplace) { + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * c, + const ggml_custom3_op_t fun, + int n_tasks, + void * userdata, + bool inplace) { GGML_ASSERT(n_tasks == GGML_N_TASKS_MAX || n_tasks > 0); - bool is_node = false; - - if (!inplace && (a->grad || b->grad || c->grad)) { - is_node = true; - } - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); struct ggml_map_custom3_op_params params = { @@ -8350,8 +7745,7 @@ static struct ggml_tensor * ggml_map_custom3_impl( }; ggml_set_op_params(result, (const void *) ¶ms, sizeof(params)); - result->op = GGML_OP_MAP_CUSTOM3; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_MAP_CUSTOM3; result->src[0] = a; result->src[1] = b; result->src[2] = c; @@ -8360,44 +7754,38 @@ static struct ggml_tensor * ggml_map_custom3_impl( } struct ggml_tensor * ggml_map_custom3( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - struct ggml_tensor * c, - const ggml_custom3_op_t fun, - int n_tasks, - void * userdata) { + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * c, + const ggml_custom3_op_t fun, + int n_tasks, + void * userdata) { return ggml_map_custom3_impl(ctx, a, b, c, fun, n_tasks, userdata, false); } struct ggml_tensor * ggml_map_custom3_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - struct ggml_tensor * c, - const ggml_custom3_op_t fun, - int n_tasks, - void * userdata) { + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * c, + const ggml_custom3_op_t fun, + int n_tasks, + void * userdata) { return ggml_map_custom3_impl(ctx, a, b, c, fun, n_tasks, userdata, true); } // ggml_cross_entropy_loss struct ggml_tensor * ggml_cross_entropy_loss( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b) { + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { GGML_ASSERT(ggml_are_same_shape(a, b)); - bool is_node = false; - - if (a->grad || b->grad) { - is_node = true; - } struct ggml_tensor * result = ggml_new_tensor_1d(ctx, a->type, 1); - result->op = GGML_OP_CROSS_ENTROPY_LOSS; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_CROSS_ENTROPY_LOSS; result->src[0] = a; result->src[1] = b; @@ -8407,17 +7795,16 @@ struct ggml_tensor * ggml_cross_entropy_loss( // ggml_cross_entropy_loss_back struct ggml_tensor * ggml_cross_entropy_loss_back( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - struct ggml_tensor * c) { + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * c) { GGML_ASSERT(ggml_are_same_shape(a, b)); GGML_ASSERT(ggml_is_scalar(c)); struct ggml_tensor * result = ggml_dup_tensor(ctx, a); - result->op = GGML_OP_CROSS_ENTROPY_LOSS_BACK; - result->grad = NULL; + result->op = GGML_OP_CROSS_ENTROPY_LOSS_BACK; result->src[0] = a; result->src[1] = b; result->src[2] = c; @@ -8435,7 +7822,7 @@ struct ggml_tensor * ggml_opt_step_adamw( float beta2, float eps, float wd) { - GGML_ASSERT(a->grad); + GGML_ASSERT(a->flags & GGML_TENSOR_FLAG_PARAM); GGML_ASSERT(alpha > 0.0f); GGML_ASSERT(beta1 >= 0.0f && beta1 <= 1.0f); GGML_ASSERT(beta2 >= 0.0f && beta2 <= 1.0f); @@ -8444,13 +7831,6 @@ struct ggml_tensor * ggml_opt_step_adamw( struct ggml_tensor * result = ggml_view_tensor(ctx, a); - result->op = GGML_OP_OPT_STEP_ADAMW; - result->grad = NULL; - result->src[0] = a; - result->src[1] = a->grad; - result->src[2] = ggml_dup_tensor(ctx, a->grad); - result->src[3] = ggml_dup_tensor(ctx, a->grad); - const int64_t iter = 1; memcpy(&result->op_params[0], &iter, sizeof(int64_t)); ggml_set_op_params_f32(result, 2, alpha); @@ -8459,26 +7839,17 @@ struct ggml_tensor * ggml_opt_step_adamw( ggml_set_op_params_f32(result, 5, eps); ggml_set_op_params_f32(result, 6, wd); + result->op = GGML_OP_OPT_STEP_ADAMW; + result->src[0] = a; + result->src[1] = a->grad; + result->src[2] = ggml_dup_tensor(ctx, a); + result->src[3] = ggml_dup_tensor(ctx, a); + return result; } //////////////////////////////////////////////////////////////////////////////// -void ggml_set_param(struct ggml_context * ctx, struct ggml_tensor * tensor) { - tensor->flags |= GGML_TENSOR_FLAG_PARAM; - - GGML_ASSERT(tensor->grad == NULL); - tensor->grad = ggml_dup_tensor(ctx, tensor); - ggml_format_name(tensor->grad, "%s (grad)", tensor->name); -} - -void ggml_set_loss(struct ggml_tensor * tensor) { - GGML_ASSERT(ggml_is_scalar(tensor)); - GGML_ASSERT(tensor->type == GGML_TYPE_F32); - GGML_ASSERT(tensor->grad); - tensor->flags |= GGML_TENSOR_FLAG_LOSS; -} - // ggml_compute_forward_dup static void ggml_compute_forward_dup_same_cont( @@ -18198,7 +17569,7 @@ void ggml_build_backward_gradient_checkpointing( struct ggml_tensor * * checkpoints, int n_checkpoints) { ggml_graph_cpy(gf, gb_tmp); - ggml_build_backward_expand(ctx, gf, gb_tmp, false, true); + ggml_build_backward_expand(ctx, gf, gb_tmp, false); if (n_checkpoints <= 0) { ggml_graph_cpy(gb_tmp, gb); @@ -18850,7 +18221,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor ggml_soft_max_back(ctx, tensor->grad, tensor), zero_table, acc_table); } - + GGML_ASSERT((!src1 || !src1->grad) && "backward pass for softmax mask not implemented"); } break; case GGML_OP_SOFT_MAX_BACK: { @@ -18891,6 +18262,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor beta_slow), zero_table, acc_table); } + GGML_ASSERT((!src2 || !src2->grad) && "gradients for freq factors not implemented"); } break; case GGML_OP_ROPE_BACK: { @@ -19012,6 +18384,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor } case GGML_OP_FLASH_ATTN_EXT: { + GGML_ABORT("FA backward pass not adapted after rework"); struct ggml_tensor * flash_grad = NULL; if (src0->grad || src1->grad || tensor->src[2]->grad) { int32_t t = ggml_get_op_params_i32(tensor, 0); @@ -19186,6 +18559,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor tensor->grad), zero_table, acc_table); } + GGML_ASSERT(!src1->grad && "backward pass for labels not implemented"); } break; case GGML_OP_CROSS_ENTROPY_LOSS_BACK: { @@ -19236,7 +18610,7 @@ static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * } } - if (node->op == GGML_OP_NONE && node->grad == NULL) { + if (node->op == GGML_OP_NONE && !(node->flags & GGML_TENSOR_FLAG_PARAM)) { // reached a leaf node, not part of the gradient graph (e.g. a constant) GGML_ASSERT(cgraph->n_leafs < cgraph->size); @@ -19254,9 +18628,6 @@ static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * } cgraph->nodes[cgraph->n_nodes] = node; - if (cgraph->grads) { - cgraph->grads[cgraph->n_nodes] = node->grad; - } cgraph->n_nodes++; } } @@ -19284,20 +18655,58 @@ void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * ggml_build_forward_impl(cgraph, tensor, true); } -void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool accumulate, bool keep) { +void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool accumulate) { GGML_ASSERT(gf->n_nodes > 0); GGML_ASSERT(gf->grads); - // if we are keeping the gradient graph, we have to detach the gradient nodes from the original graph - if (keep) { - for (int i = 0; i < gf->n_nodes; i++) { - struct ggml_tensor * node = gf->nodes[i]; + for (int i = 0; i < gf->n_nodes; ++i) { + struct ggml_tensor * node = gf->nodes[i]; - if (node->grad) { - node->grad = ggml_dup_tensor(ctx, node); - gf->grads[i] = node->grad; - } + bool needs_grad = node->flags & GGML_TENSOR_FLAG_PARAM; + bool ignore_src[GGML_MAX_SRC] = {false}; + switch (node->op) { + // gradients in node->src[0] for one reason or another have no effect on output gradients + case GGML_OP_IM2COL: // only used for its shape + case GGML_OP_IM2COL_BACK: // same as IM2COL + ignore_src[0] = true; + break; + case GGML_OP_UNARY: { + const enum ggml_unary_op uop = ggml_get_unary_op(node); + // SGN and STEP unary ops are piecewise constant + if (uop == GGML_UNARY_OP_SGN || uop == GGML_UNARY_OP_STEP) { + ignore_src[0] = true; + } + } break; + + // gradients in node->src[1] for one reason or another have no effect on output gradients + case GGML_OP_CPY: // gradients in CPY target are irrelevant + case GGML_OP_GET_ROWS: // row indices not differentiable + case GGML_OP_GET_ROWS_BACK: // same as for GET_ROWS + case GGML_OP_ROPE: // positions not differentiable + ignore_src[1] = true; + break; + + default: + break; } + for (int j = 0; j < GGML_MAX_SRC; ++j) { + if (!node->src[j] || !node->src[j]->grad || ignore_src[j]) { + continue; + } + GGML_ASSERT(node->src[j]->type == GGML_TYPE_F32 || node->src[j]->type == GGML_TYPE_F16); + needs_grad = true; + break; + } + if (!needs_grad) { + continue; + } + + // inplace operations are currently not supported + GGML_ASSERT(!node->view_src || node->op == GGML_OP_CPY || node->op == GGML_OP_VIEW || + node->op == GGML_OP_RESHAPE || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_TRANSPOSE); + + // create a new tensor with the same type and shape as the node and set it as grad + node->grad = ggml_dup_tensor(ctx, node); } // keep tables of original gradients for replacement/accumulation logic @@ -22162,8 +21571,6 @@ enum ggml_opt_result ggml_opt( struct ggml_context * ctx, struct ggml_opt_params params, struct ggml_tensor * f) { - GGML_ASSERT(f->grad && "ggml_set_param called for at least one parent tensor."); - bool free_ctx = false; if (ctx == NULL) { struct ggml_init_params params_ctx = { @@ -22204,7 +21611,7 @@ enum ggml_opt_result ggml_opt_resume( ggml_build_forward_expand(gf, f); struct ggml_cgraph * gb = ggml_graph_dup(ctx, gf); - ggml_build_backward_expand(ctx, gf, gb, false, true); + ggml_build_backward_expand(ctx, gf, gb, false); return ggml_opt_resume_g(ctx, opt, f, gf, gb, NULL, NULL); } @@ -22257,6 +21664,17 @@ void ggml_set_output(struct ggml_tensor * tensor) { tensor->flags |= GGML_TENSOR_FLAG_OUTPUT; } +void ggml_set_param(struct ggml_context * ctx, struct ggml_tensor * tensor) { + GGML_UNUSED(ctx); // TODO: remove this parameter + tensor->flags |= GGML_TENSOR_FLAG_PARAM; +} + +void ggml_set_loss(struct ggml_tensor * tensor) { + GGML_ASSERT(ggml_is_scalar(tensor)); + GGML_ASSERT(tensor->type == GGML_TYPE_F32); + tensor->flags |= GGML_TENSOR_FLAG_LOSS; +} + //////////////////////////////////////////////////////////////////////////////// void ggml_quantize_init(enum ggml_type type) { diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index d2cfe06b5..5c78b6704 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1,6 +1,6 @@ // This file defines tests for various GGML ops and backends. // For the forward pass it asserts that the results of multiple backends computing the same GGML ops are consistent. -// For the backwards pass it asserts that the gradients from backpropagation are consistent +// For the backward pass it asserts that the gradients from backpropagation are consistent // with the gradients obtained via the method of finite differences ("grad" mode, this is optional). // It is also possible to check the performance ("perf" mode). // @@ -740,7 +740,7 @@ struct test_case { ggml_tensor * out = build_graph(ctx); - if (op_name != nullptr && op_desc(out) != op_name) { + if ((op_name != nullptr && op_desc(out) != op_name) || out->op == GGML_OP_OPT_STEP_ADAMW) { //printf(" %s: skipping\n", op_desc(out).c_str()); ggml_free(ctx); return true; @@ -749,11 +749,6 @@ struct test_case { printf(" %s(%s): ", op_desc(out).c_str(), vars().c_str()); fflush(stdout); - if (out->grad == nullptr) { - printf("backwards pass not supported \n"); - ggml_free(ctx); - return true; - } if (out->type != GGML_TYPE_F32) { ggml_free(ctx); printf("not supported [%s->type != FP32]\n", out->name); @@ -762,18 +757,26 @@ struct test_case { // check if the backend supports the ops bool supported = true; + bool any_params = false; for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { if (!ggml_backend_supports_op(backend, t)) { printf("not supported [%s] ", ggml_backend_name(backend)); supported = false; break; } - if ((t->flags & GGML_TENSOR_FLAG_PARAM) && t->type != GGML_TYPE_F32) { - printf("not supported [%s->type != FP32] ", t->name); - supported = false; - break; + if ((t->flags & GGML_TENSOR_FLAG_PARAM)) { + any_params = true; + if (t->type != GGML_TYPE_F32) { + printf("not supported [%s->type != FP32] ", t->name); + supported = false; + break; + } } } + if (!any_params) { + printf("not supported [%s] \n", op_name); + supported = false; + } if (!supported) { printf("\n"); ggml_free(ctx); @@ -801,7 +804,7 @@ struct test_case { ggml_build_forward_expand(gf, out); ggml_graph_cpy(gf, gb); - ggml_build_backward_expand(ctx, gf, gb, false, false); + ggml_build_backward_expand(ctx, gf, gb, false); if (expect.size() != 1 || expect[0] != 0.0f) { GGML_ASSERT(ggml_graph_n_nodes(gb) > ggml_graph_n_nodes(gf)); for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { @@ -984,7 +987,7 @@ struct test_example : public test_case { } // In order to also check the gradients for your op, add calls like ggml_set_param(ctx, a) // immediately after you create the tensors. - // This is optional and only makes sense if a backwards pass has actually been implemented for the new op. + // This is optional and only makes sense if a backward pass has actually been implemented for the new op. }; @@ -1223,7 +1226,7 @@ struct test_set : public test_case { offset += ((ne_dst[i] - ne[i])/2)*dst->nb[i]; } ggml_tensor * out = ggml_set(ctx, dst, src, - // The backwards pass requires setting a contiguous region: + // The backward pass requires setting a contiguous region: src->nb[1], src->nb[2], src->nb[3], offset); ggml_set_name(out, "out"); @@ -1335,7 +1338,7 @@ struct test_bin_bcast : public test_case { ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data()); ggml_set_name(b, "b"); - // The backwards pass supports broadcasting only for GGML_ADD: + // The backward pass supports broadcasting only for GGML_ADD: const bool grad_supported = op == ggml_add || ggml_are_same_shape(a, b); if (grad_supported) { ggml_set_param(ctx, a); @@ -1830,7 +1833,7 @@ struct test_log : public test_case { void initialize_tensors(ggml_context * ctx) override { for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { - // log(1) == 0, cluster values there to keep the sum low for better precision in the backwards pass: + // log(1) == 0, cluster values there to keep the sum low for better precision in the backward pass: init_tensor_uniform(t, 0.9f, 1.1f); } } @@ -3257,7 +3260,7 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,1,2,1}, 1, 0, 1)); test_cases.emplace_back(new test_conv_transpose_1d({2,1,1,1}, {3,1,1,1}, 1, 0, 1)); - for (int ne3 : {1, 3}) { // CUDA backwards pass only supports ne3 == 1 + for (int ne3 : {1, 3}) { // CUDA backward pass only supports ne3 == 1 test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, ne3}, {1, 1, 1, 1})); test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, ne3}, {2, 1, 1, 1})); test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, ne3}, {1, 2, 1, 1})); diff --git a/tests/test-grad0.cpp b/tests/test-grad0.cpp index 2ef606d2c..2200ad93d 100644 --- a/tests/test-grad0.cpp +++ b/tests/test-grad0.cpp @@ -240,12 +240,14 @@ static bool check_gradient( struct ggml_cgraph * gb = ggml_new_graph_custom(ctx0, GGML_DEFAULT_GRAPH_SIZE, true); ggml_build_forward_expand(gf, f); ggml_graph_cpy(gf, gb); - ggml_build_backward_expand(ctx0, gf, gb, false, false); + ggml_build_backward_expand(ctx0, gf, gb, false); ggml_graph_compute_with_ctx(ctx0, gf, n_threads); - ggml_graph_reset (gf); - ggml_set_f32 (f->grad, 1.0f); + ggml_graph_reset(gb); + if (f->grad) { + ggml_set_f32(f->grad, 1.0f); + } ggml_graph_compute_with_ctx(ctx0, gb, n_threads); @@ -298,8 +300,10 @@ static bool check_gradient( ggml_set_f32_1d(x[i], k, x0); // compute gradient using backward graph - ggml_graph_reset (gf); - ggml_set_f32 (f->grad, 1.0f); + ggml_graph_reset(gb); + if (f->grad) { + ggml_set_f32(f->grad, 1.0f); + } ggml_graph_compute_with_ctx(ctx0, gb, n_threads); From 6c5322481a75f1be54e58a000c3f78484d07f948 Mon Sep 17 00:00:00 2001 From: Borislav Stanimirov Date: Mon, 30 Sep 2024 10:11:41 +0300 Subject: [PATCH 30/33] ggml : fix ggml_cast (ggml/973) --- ggml/src/ggml.c | 1 + 1 file changed, 1 insertion(+) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index fa8d6c25a..aac4e3a7b 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -5818,6 +5818,7 @@ struct ggml_tensor * ggml_cast( result->op = GGML_OP_CPY; result->src[0] = a; + result->src[1] = result; return result; } From cb00020504416606601f8cb35f55ee07b710b4b7 Mon Sep 17 00:00:00 2001 From: Salvatore Mesoraca Date: Mon, 30 Sep 2024 09:14:09 +0200 Subject: [PATCH 31/33] vulkan : mul_mat: fix UB with small warps (ggml/952) When the device's warp size is less than 16, it is possible for loadstride_a (mul_mm.comp:114) and loadstride_b (mul_mm.comp:115) to be set to 0. Because they are calculated as: the workgroup size, multiplied by LOAD_VEC_* (which can be 1) and divided by 16. And the workgroup size is set to be the same as the warp/subgroup size. The loadstride_* variables are used as increments in the loops that populate the buffers used for the multiplication. When they are 0 they cause an infinite loop. But infinite loops without side-effects are UB and the values of loadstride_* are known at compile time. So, the compiler quietly optimizes all the loops away. As a consequence, the buffers are not populated and the multiplication result is just a matrix with all elements set to 0. We prevent the UB by making sure that the workgroup size will never be less than 16, even if our device has a smaller warp size (e.g. 8). Signed-off-by: Salvatore Mesoraca --- ggml/src/ggml-vulkan.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-vulkan.cpp b/ggml/src/ggml-vulkan.cpp index c677a2728..00ad13bb9 100644 --- a/ggml/src/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan.cpp @@ -1164,11 +1164,11 @@ static void ggml_vk_load_shaders(vk_device& device) { // mulmat std::initializer_list warptile_l = { 128, 128, 128, 16, device->subgroup_size * 2, 64, 2, 4, 4, device->subgroup_size }; std::initializer_list warptile_m = { 128, 64, 64, 16, device->subgroup_size, 32, 2, 4, 2, device->subgroup_size }; - std::initializer_list warptile_s = { device->subgroup_size, 32, 32, 16, 32, 32, 2, 2, 2, device->subgroup_size }; + std::initializer_list warptile_s = { std::max(device->subgroup_size, 16u), 32, 32, 16, 32, 32, 2, 2, 2, device->subgroup_size }; std::initializer_list warptile_mmq_l = { 128, 128, 128, 32, device->subgroup_size * 2, 64, 2, 4, 4, device->subgroup_size }; std::initializer_list warptile_mmq_m = { 128, 64, 64, 32, device->subgroup_size, 32, 2, 4, 2, device->subgroup_size }; - std::initializer_list warptile_mmq_s = { device->subgroup_size, 32, 32, 32, 32, 32, 2, 2, 2, device->subgroup_size }; + std::initializer_list warptile_mmq_s = { std::max(device->subgroup_size, 16u), 32, 32, 32, 32, 32, 2, 2, 2, device->subgroup_size }; std::array l_wg_denoms = {128, 128, 1 }; std::array m_wg_denoms = { 64, 64, 1 }; From e98c1c188ebbeb55d0cd6389ad03f0cbf995b451 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Mon, 30 Sep 2024 09:55:23 +0200 Subject: [PATCH 32/33] test: fix OPT_STEP_ADAMW for test-backend-ops (ggml/974) --- ggml/include/ggml.h | 1 + ggml/src/ggml.c | 10 ++++++---- tests/test-backend-ops.cpp | 5 ++++- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index a8a74bee1..ce3d92cb2 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -2052,6 +2052,7 @@ extern "C" { GGML_API struct ggml_tensor * ggml_opt_step_adamw( struct ggml_context * ctx, struct ggml_tensor * a, + struct ggml_tensor * grad, float alpha, float beta1, float beta2, diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index aac4e3a7b..bcbc32d91 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -7818,12 +7818,14 @@ struct ggml_tensor * ggml_cross_entropy_loss_back( struct ggml_tensor * ggml_opt_step_adamw( struct ggml_context * ctx, struct ggml_tensor * a, + struct ggml_tensor * grad, float alpha, float beta1, float beta2, float eps, float wd) { GGML_ASSERT(a->flags & GGML_TENSOR_FLAG_PARAM); + GGML_ASSERT(ggml_are_same_shape(a, grad)); GGML_ASSERT(alpha > 0.0f); GGML_ASSERT(beta1 >= 0.0f && beta1 <= 1.0f); GGML_ASSERT(beta2 >= 0.0f && beta2 <= 1.0f); @@ -7842,9 +7844,9 @@ struct ggml_tensor * ggml_opt_step_adamw( result->op = GGML_OP_OPT_STEP_ADAMW; result->src[0] = a; - result->src[1] = a->grad; - result->src[2] = ggml_dup_tensor(ctx, a); - result->src[3] = ggml_dup_tensor(ctx, a); + result->src[1] = grad; + result->src[2] = ggml_dup_tensor(ctx, grad); + result->src[3] = ggml_dup_tensor(ctx, grad); return result; } @@ -18769,7 +18771,7 @@ void ggml_build_opt_adamw( if (node->flags & GGML_TENSOR_FLAG_PARAM) { GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node); - struct ggml_tensor * opt_step = ggml_opt_step_adamw(ctx, node, alpha, beta1, beta2, eps, wd); + struct ggml_tensor * opt_step = ggml_opt_step_adamw(ctx, node, node->grad, alpha, beta1, beta2, eps, wd); ggml_build_forward_expand(gb, opt_step); } } diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 5c78b6704..95d983aa0 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -2751,7 +2751,10 @@ struct test_opt_step_adamw : public test_case { ggml_set_param(ctx, a); // Despite tensor a having gradients the output tensor will not. ggml_set_name(a, "a"); - ggml_tensor * out = ggml_opt_step_adamw(ctx, a, alpha, beta1, beta2, eps, wd); + ggml_tensor * grad = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]); + ggml_set_name(grad, "grad"); + + ggml_tensor * out = ggml_opt_step_adamw(ctx, a, grad, alpha, beta1, beta2, eps, wd); ggml_set_name(out, "out"); return out; From f1b8c4271125c2bfb9cfebd72a8b9e9f99061a30 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 1 Oct 2024 16:09:42 +0300 Subject: [PATCH 33/33] sync : ggml --- scripts/sync-ggml.last | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index aa301462a..23c24899e 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -9a24b8c8c40eab7262d067e91d08df160678df8d +4de6ee8e6a4b2145d6b92162bc87722fecb4ea46