From 0ab30f8d82fc7156b750c194d64a887e80cbfb82 Mon Sep 17 00:00:00 2001 From: Sutou Kouhei Date: Sat, 31 Aug 2024 03:08:10 +0900 Subject: [PATCH 01/22] llama : fix llama_split_mode enum values in main_gpu document (#9057) LLAMA_SPLIT_* were renamed to LLAMA_SPLIT_MODE_* in #5697. --- include/llama.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/include/llama.h b/include/llama.h index c3bda9e02..f2e701602 100644 --- a/include/llama.h +++ b/include/llama.h @@ -267,9 +267,9 @@ extern "C" { enum llama_split_mode split_mode; // how to split the model across multiple GPUs // main_gpu interpretation depends on split_mode: - // LLAMA_SPLIT_NONE: the GPU that is used for the entire model - // LLAMA_SPLIT_ROW: the GPU that is used for small tensors and intermediate results - // LLAMA_SPLIT_LAYER: ignored + // LLAMA_SPLIT_MODE_NONE: the GPU that is used for the entire model + // LLAMA_SPLIT_MODE_ROW: the GPU that is used for small tensors and intermediate results + // LLAMA_SPLIT_MODE_LAYER: ignored int32_t main_gpu; // proportion of the model (layers or rows) to offload to each GPU, size: llama_max_devices() From 49271efbaf3f7a4ae52c4ca299b6d4e82598a97d Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Sat, 31 Aug 2024 09:50:22 +0200 Subject: [PATCH 02/22] llama : fix typo in xcda_array_view comment [no ci] (#9132) --- src/llama-vocab.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index 323660ef5..dc4138cbc 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -963,7 +963,7 @@ private: /* * This structure is a view wrapper for XOR-compressed double array (XCDA) * See Shunsuke Kanda (2018). Space- and Time-Efficient String Dictionaries. - * Eeach bit-packed entry contains: + * Each bit-packed entry contains: * - BASE array value in bits 10-30 * - LCHECK array value in bits 0-7 * - LEAF array value in bit 9 From ea5d7478b1edcfc83f8ea8f9f0934585cc0b92e3 Mon Sep 17 00:00:00 2001 From: Srihari-mcw <96763064+Srihari-mcw@users.noreply.github.com> Date: Sat, 31 Aug 2024 13:50:35 +0530 Subject: [PATCH 03/22] sgemm : improved Q4_0 and Q8_0 performance via 4xN and Mx4 gemm (#8908) --- ggml/src/llamafile/sgemm.cpp | 149 +++++++++++++++++++++++++++++++++++ 1 file changed, 149 insertions(+) diff --git a/ggml/src/llamafile/sgemm.cpp b/ggml/src/llamafile/sgemm.cpp index 6626ceb26..f0988ba7c 100644 --- a/ggml/src/llamafile/sgemm.cpp +++ b/ggml/src/llamafile/sgemm.cpp @@ -606,17 +606,29 @@ class tinyBLAS_Q0_AVX { case 0x44: mc = 4; nc = 4; +#if defined(__AVX2__) && defined(__F16C__) + gemm4xN<4>(m0, m, n0, n); +#else gemm<4, 4>(m0, m, n0, n); +#endif break; case 0x43: mc = 4; nc = 3; +#if defined(__AVX2__) && defined(__F16C__) + gemm4xN<3>(m0, m, n0, n); +#else gemm<4, 3>(m0, m, n0, n); +#endif break; case 0x34: mc = 3; nc = 4; +#if defined(__AVX2__) && defined(__F16C__) + gemmMx4<3>(m0, m, n0, n); +#else gemm<3, 4>(m0, m, n0, n); +#endif break; case 0x33: mc = 3; @@ -626,12 +638,20 @@ class tinyBLAS_Q0_AVX { case 0x42: mc = 4; nc = 2; +#if defined(__AVX2__) && defined(__F16C__) + gemm4xN<2>(m0, m, n0, n); +#else gemm<4, 2>(m0, m, n0, n); +#endif break; case 0x24: mc = 2; nc = 4; +#if defined(__AVX2__) && defined(__F16C__) + gemmMx4<2>(m0, m, n0, n); +#else gemm<2, 4>(m0, m, n0, n); +#endif break; #else case 0x44: @@ -639,13 +659,21 @@ class tinyBLAS_Q0_AVX { case 0x42: mc = 4; nc = 2; +#if defined(__AVX2__) && defined(__F16C__) + gemm4xN<2>(m0, m, n0, n); +#else gemm<4, 2>(m0, m, n0, n); +#endif break; case 0x34: case 0x24: mc = 2; nc = 4; +#if defined(__AVX2__) && defined(__F16C__) + gemmMx4<2>(m0, m, n0, n); +#else gemm<2, 4>(m0, m, n0, n); +#endif break; case 0x33: #endif @@ -662,7 +690,11 @@ class tinyBLAS_Q0_AVX { case 0x41: mc = 4; nc = 1; +#if defined(__AVX2__) && defined(__F16C__) + gemm4xN<1>(m0, m, n0, n); +#else gemm<4, 1>(m0, m, n0, n); +#endif break; case 0x22: mc = 2; @@ -672,7 +704,11 @@ class tinyBLAS_Q0_AVX { case 0x14: mc = 1; nc = 4; +#if defined(__AVX2__) && defined(__F16C__) + gemmMx4<1>(m0, m, n0, n); +#else gemm<1, 4>(m0, m, n0, n); +#endif break; case 0x31: mc = 3; @@ -708,6 +744,119 @@ class tinyBLAS_Q0_AVX { mnpack(m0, m, np, n); } +#if defined(__AVX2__) && defined(__F16C__) +// Templated functions for gemm of dimensions 4xN + template + NOINLINE void gemm4xN(int64_t m0, int64_t m, int64_t n0, int64_t n) { + int64_t ytiles = (m - m0) / 4; + int64_t xtiles = (n - n0) / RN; + int64_t tiles = xtiles * ytiles; + int64_t duty = (tiles + nth - 1) / nth; + int64_t start = duty * ith; + int64_t end = start + duty; + if (end > tiles) + end = tiles; + for (int64_t job = start; job < end; ++job) { + int64_t ii = m0 + job / xtiles * 4; + int64_t jj = n0 + job % xtiles * RN; + __m256 Cv[RN][4] = {}; + for (int64_t l = 0; l < k; ++l) { + uint64_t a_delta = ((uint64_t)A[lda * (ii + 3) + l].d << 48) | ((uint64_t)A[lda * (ii + 2) + l].d << 32) | ((uint64_t)A[lda * (ii + 1) + l].d << 16) | (A[lda * (ii + 0) + l].d); + // Convert delta values for four blocks to float values + __m128 da = _mm_cvtph_ps(_mm_set_epi64x(0, a_delta)); + __m256i avec0 = load(A + lda * (ii + 0) + l); + __m256i avec1 = load(A + lda * (ii + 1) + l); + __m256i avec2 = load(A + lda * (ii + 2) + l); + __m256i avec3 = load(A + lda * (ii + 3) + l); + for (int64_t j = 0; j < RN; ++j) { + __m128 db = _mm_set1_ps(unhalf(B[ldb * (jj + j) + l].d)); + // Computation of product of delta values for four blocks and replicate it across 256 bit lane + __m256 dvec = _mm256_castps128_ps256(_mm_mul_ps(da, db)); + dvec = _mm256_permute2f128_ps(dvec ,dvec, 0); + // Computation of dot product and multiplication with appropriate delta value products + Cv[j][0] = madd(_mm256_shuffle_ps(dvec, dvec, 0), + updot(_mm256_sign_epi8(avec0, avec0), + _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec0)), + Cv[j][0]); + Cv[j][1] = madd(_mm256_shuffle_ps(dvec, dvec, 85), + updot(_mm256_sign_epi8(avec1, avec1), + _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec1)), + Cv[j][1]); + Cv[j][2] = madd(_mm256_shuffle_ps(dvec, dvec, 170), + updot(_mm256_sign_epi8(avec2, avec2), + _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec2)), + Cv[j][2]); + Cv[j][3] = madd(_mm256_shuffle_ps(dvec, dvec, 255), + updot(_mm256_sign_epi8(avec3, avec3), + _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec3)), + Cv[j][3]); + } + } + + for (int64_t j = 0; j < RN; ++j) + for (int64_t i = 0; i < 4; ++i) + C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]); + } + } + + // Templated functions for gemm of dimensions Mx4 + template + NOINLINE void gemmMx4(int64_t m0, int64_t m, int64_t n0, int64_t n) { + int64_t ytiles = (m - m0) / RM; + int64_t xtiles = (n - n0) / 4; + int64_t tiles = xtiles * ytiles; + int64_t duty = (tiles + nth - 1) / nth; + int64_t start = duty * ith; + int64_t end = start + duty; + if (end > tiles) + end = tiles; + for (int64_t job = start; job < end; ++job) { + int64_t ii = m0 + job / xtiles * RM; + int64_t jj = n0 + job % xtiles * 4; + __m256 Cv[4][RM] = {}; + for (int64_t l = 0; l < k; ++l) { + uint64_t b_delta = ((uint64_t)B[ldb * (jj + 3) + l].d << 48) | ((uint64_t)B[ldb * (jj + 2) + l].d << 32) | ((uint64_t)B[ldb * (jj + 1) + l].d << 16) | (B[ldb * (jj + 0) + l].d); + // Convert delta values for four blocks to float values + __m128 db = _mm_cvtph_ps(_mm_set_epi64x(0, b_delta)); + __m256i bvec0 = load(B + ldb * (jj + 0) + l); + __m256i bvec1 = load(B + ldb * (jj + 1) + l); + __m256i bvec2 = load(B + ldb * (jj + 2) + l); + __m256i bvec3 = load(B + ldb * (jj + 3) + l); + for (int64_t i = 0; i < RM; ++i) { + __m128 da = _mm_set1_ps(unhalf((A[lda * (ii + i) + l].d))); + // Computation of product of delta values for four blocks and replicate it across 256 bit lane + __m256 dvec = _mm256_castps128_ps256(_mm_mul_ps(da, db)); + dvec = _mm256_permute2f128_ps(dvec ,dvec, 0); + // Computation of dot product and multiplication with appropriate delta value products + Cv[0][i] = madd(_mm256_shuffle_ps(dvec, dvec, 0), + updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l), + load(A + lda * (ii + i) + l)), + _mm256_sign_epi8(bvec0, load(A + lda * (ii + i) + l))), + Cv[0][i]); + Cv[1][i] = madd(_mm256_shuffle_ps(dvec, dvec, 85), + updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l), + load(A + lda * (ii + i) + l)), + _mm256_sign_epi8(bvec1, load(A + lda * (ii + i) + l))), + Cv[1][i]); + Cv[2][i] = madd(_mm256_shuffle_ps(dvec, dvec, 170), + updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l), + load(A + lda * (ii + i) + l)), + _mm256_sign_epi8(bvec2, load(A + lda * (ii + i) + l))), + Cv[2][i]); + Cv[3][i] = madd(_mm256_shuffle_ps(dvec, dvec, 255), + updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l), + load(A + lda * (ii + i) + l)), + _mm256_sign_epi8(bvec3, load(A + lda * (ii + i) + l))), + Cv[3][i]); + } + } + for (int64_t j = 0; j < 4; ++j) + for (int64_t i = 0; i < RM; ++i) + C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]); + } + } +#endif + template NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) { int64_t ytiles = (m - m0) / RM; From a47667cff41f5a198eb791974e0afcc1cddd3229 Mon Sep 17 00:00:00 2001 From: Echo Nolan Date: Thu, 22 Aug 2024 17:19:14 -0400 Subject: [PATCH 04/22] nix: fix CUDA build - replace deprecated autoAddOpenGLRunpathHook The CUDA nix build broke when we updated nixpkgs in 8cd1bcfd3fc9f2b5cbafd7fb7581b3278acec25f. As far as I can tell all that happened is cudaPackages.autoAddOpenGLRunpathHook got moved to pkgs.autoAddDriverRunpath. This commit fixes it. --- .devops/nix/package.nix | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/.devops/nix/package.nix b/.devops/nix/package.nix index a87423c71..cfffac257 100644 --- a/.devops/nix/package.nix +++ b/.devops/nix/package.nix @@ -13,6 +13,7 @@ mpi, blas, cudaPackages, + autoAddDriverRunpath, darwin, rocmPackages, vulkan-headers, @@ -192,10 +193,7 @@ effectiveStdenv.mkDerivation ( ] ++ optionals useCuda [ cudaPackages.cuda_nvcc - - # TODO: Replace with autoAddDriverRunpath - # once https://github.com/NixOS/nixpkgs/pull/275241 has been merged - cudaPackages.autoAddOpenGLRunpathHook + autoAddDriverRunpath ] ++ optionals (effectiveStdenv.hostPlatform.isGnu && enableStatic) [ glibc.static From 8f1d81a0b6f50b9bad72db0b6fcd299ad9ecd48c Mon Sep 17 00:00:00 2001 From: Molly Sophia Date: Sun, 1 Sep 2024 22:38:17 +0800 Subject: [PATCH 05/22] llama : support RWKV v6 models (#8980) * convert_hf_to_gguf: Add support for RWKV v6 Signed-off-by: Molly Sophia * Add RWKV tokenization * Fix build Signed-off-by: Molly Sophia * Do not use special tokens when matching in RWKV tokenizer * Fix model loading * Add (broken) placeholder graph builder for RWKV * Add workaround for kv cache * Add logits conversion to rwkv5 * Add rwkv5 layer norms * Add time mix KVRG & correct merge mistake * Add remaining time mix parameters * Add time mix output loading * Add placeholder llm_build_time_mix * Fix build Signed-off-by: Molly Sophia * Load more tensors for rwkv v6 Signed-off-by: Molly Sophia * Fix rwkv tokenizer Signed-off-by: Molly Sophia * ggml: Add unary operator Exp Signed-off-by: Molly Sophia * RWKV v6 graph building Signed-off-by: Molly Sophia * Add ``rescale_every_n_layers`` parameter Signed-off-by: Molly Sophia * Add ``wkv.head_size`` key for RWKV so it doesn't reuse Mamba ssm parameters Signed-off-by: Molly Sophia * Fix offloading layers to CUDA Signed-off-by: Molly Sophia * Fix parallel inferencing for RWKV Signed-off-by: Molly Sophia * Remove trailing whitespaces Signed-off-by: Molly Sophia * build_rwkv: Avoid using inplace operations Signed-off-by: Molly Sophia * convert_hf_to_gguf: rwkv: Avoid using ``eval`` Signed-off-by: Molly Sophia * convert_hf_to_gguf: rwkv tokenizer: Don't escape sequences manually Signed-off-by: Molly Sophia * Update convert_hf_to_gguf.py Co-authored-by: compilade * ggml: Add backward computation for unary op ``exp`` Signed-off-by: Molly Sophia * Update convert_hf_to_gguf.py Co-authored-by: compilade * Update convert_hf_to_gguf.py Co-authored-by: compilade * Use MODEL_ARCH.RWKV6 instead of MODEL_ARCH.RWKV Signed-off-by: Molly Sophia * build_rwkv6: Simplify graph Signed-off-by: Molly Sophia * llama: rwkv6: Detect model.type Signed-off-by: Molly Sophia * llama: rwkv6: Fix tensor loading for 7B/14B models Signed-off-by: Molly Sophia * llama: rwkv6: Fix group_norm assertion failure with Metal Signed-off-by: Molly Sophia * llama: rwkv6: Clean up Signed-off-by: Molly Sophia * llama: rwkv6: Add quantization tensor exclusion Signed-off-by: Molly Sophia * llama: rwkv6: Use the new advanced batch splits Signed-off-by: Molly Sophia * Update src/llama.cpp Co-authored-by: compilade * llama: rwkv6: Use ``ggml_norm`` instead of ``ggml_group_norm`` Co-authored-by: compilade * llama: rwkv6: Apply code style and misc changes Signed-off-by: Molly Sophia * converter: Use class name ``Rwkv6Model`` Signed-off-by: Molly Sophia * llama: rwkv6: Make use of key ``feed_forward_length`` Signed-off-by: Molly Sophia * llama: rwkv6: Add kv ``time_mix_extra_dim`` and ``time_decay_extra_dim`` Signed-off-by: Molly Sophia * converter: Match ``new_name`` instead of ``name`` for float32 explicit tensors Signed-off-by: Molly Sophia * llama: rwkv6: Keep ``time_mix_w1/w2`` as F32 Signed-off-by: Molly Sophia * llama: rwkv6: Remove unused nodes Signed-off-by: Molly Sophia * llama: rwkv6: Apply code format changes Signed-off-by: Molly Sophia * llama: rwkv6: Add lora for some supported tensors Currently att.key/receptance/value/gate/output, ffn.receptance/key/value, as well as head.weight Signed-off-by: Molly Sophia * rwkv : speed-up tokenization using trie * minor : style + indentation * llama: rwkv6: Avoid division by zero Co-authored-by: compilade * ggml: rwkv_wkv: Avoid copying the state Signed-off-by: Molly Sophia --------- Signed-off-by: Molly Sophia Co-authored-by: Layl Bongers <3094382+LaylBongers@users.noreply.github.com> Co-authored-by: compilade Co-authored-by: Georgi Gerganov --- convert_hf_to_gguf.py | 84 +++++- ggml/include/ggml.h | 19 ++ ggml/src/ggml.c | 228 +++++++++++++- gguf-py/gguf/constants.py | 247 ++++++++++----- gguf-py/gguf/gguf_writer.py | 12 + gguf-py/gguf/tensor_mapping.py | 100 ++++++- include/llama.h | 1 + src/llama-vocab.cpp | 145 ++++++++- src/llama.cpp | 533 ++++++++++++++++++++++++++++++++- 9 files changed, 1266 insertions(+), 103 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index caa41aee5..27ac34b81 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -3,6 +3,7 @@ from __future__ import annotations +import ast import logging import argparse import contextlib @@ -298,9 +299,12 @@ class Model: gguf.MODEL_TENSOR.POS_EMBD, gguf.MODEL_TENSOR.TOKEN_TYPES, gguf.MODEL_TENSOR.SSM_CONV1D, + gguf.MODEL_TENSOR.TIME_MIX_FIRST, + gguf.MODEL_TENSOR.TIME_MIX_W1, + gguf.MODEL_TENSOR.TIME_MIX_W2, ) ) - or not name.endswith(".weight") + or not new_name.endswith(".weight") ): data_qtype = gguf.GGMLQuantizationType.F32 @@ -2716,6 +2720,84 @@ class StarCoder2Model(Model): model_arch = gguf.MODEL_ARCH.STARCODER2 +@Model.register("Rwkv6ForCausalLM") +class Rwkv6Model(Model): + model_arch = gguf.MODEL_ARCH.RWKV6 + + def set_vocab(self): + assert (self.dir_model / "rwkv_vocab_v20230424.txt").is_file() + vocab_size = self.hparams.get("vocab_size", 65536) + + tokens: list[bytes] = [''.encode("utf-8")] + toktypes: list[int] = [gguf.TokenType.CONTROL] + + with open(self.dir_model / "rwkv_vocab_v20230424.txt", "r", encoding="utf-8") as f: + lines = f.readlines() + for line in lines: + parts = line.split(' ') + assert len(parts) >= 3 + token, token_len = ast.literal_eval(' '.join(parts[1:-1])), int(parts[-1]) + token = token.encode("utf-8") if isinstance(token, str) else token + assert isinstance(token, bytes) + assert len(token) == token_len + token_text: str = repr(token)[2:-1] # "b'\xff'" -> "\xff" + tokens.append(token_text.encode("utf-8")) + toktypes.append(gguf.TokenType.NORMAL) + remainder = vocab_size - len(tokens) + assert remainder >= 0 + for i in range(len(tokens), vocab_size): + tokens.append(f"[PAD{i}]".encode("utf-8")) + toktypes.append(gguf.TokenType.UNUSED) + + self.gguf_writer.add_tokenizer_model("rwkv") + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_types(toktypes) + + def set_gguf_parameters(self): + block_count = self.hparams["num_hidden_layers"] + head_size = self.hparams["head_size"] + hidden_size = self.hparams["hidden_size"] + layer_norm_eps = self.hparams["layer_norm_epsilon"] + rescale_every_n_layers = self.hparams["rescale_every"] + intermediate_size = self.hparams["intermediate_size"] if self.hparams["intermediate_size"] is not None else int((hidden_size * 3.5) // 32 * 32) + time_mix_extra_dim = 64 if hidden_size == 4096 else 32 + time_decay_extra_dim = 128 if hidden_size == 4096 else 64 + + # RWKV isn't context limited + self.gguf_writer.add_context_length(1048576) + self.gguf_writer.add_embedding_length(hidden_size) + self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_layer_norm_eps(layer_norm_eps) + self.gguf_writer.add_rescale_every_n_layers(rescale_every_n_layers) + self.gguf_writer.add_wkv_head_size(head_size) + self.gguf_writer.add_time_mix_extra_dim(time_mix_extra_dim) + self.gguf_writer.add_time_decay_extra_dim(time_decay_extra_dim) + self.gguf_writer.add_feed_forward_length(intermediate_size) + self.gguf_writer.add_file_type(self.ftype) + + # required by llama.cpp, unused + self.gguf_writer.add_head_count(0) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + new_name = self.map_tensor_name(name) + + if not (new_name.endswith(".weight") or new_name.endswith(".bias")): + new_name += ".weight" + + if new_name.endswith("time_mix_w1.weight") or new_name.endswith("time_mix_decay_w1.weight") or new_name.endswith("time_mix_decay_w2.weight"): + data_torch = data_torch.transpose(0, 1) + + if new_name.endswith("time_mix_w2.weight"): + data_torch = data_torch.permute(0, 2, 1) + + rescale_every_n_layers = self.hparams["rescale_every"] + if rescale_every_n_layers > 0: + if new_name.endswith("time_mix_output.weight") or new_name.endswith("channel_mix_value.weight"): + data_torch = data_torch.div_(2 ** int(bid // rescale_every_n_layers)) + + yield (new_name, data_torch) + + @Model.register("MambaForCausalLM", "MambaLMHeadModel", "FalconMambaForCausalLM") class MambaModel(Model): model_arch = gguf.MODEL_ARCH.MAMBA diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 5233a9995..3fb680360 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -514,6 +514,7 @@ extern "C" { GGML_OP_WIN_UNPART, GGML_OP_GET_REL_POS, GGML_OP_ADD_REL_POS, + GGML_OP_RWKV_WKV, GGML_OP_UNARY, @@ -548,6 +549,7 @@ extern "C" { GGML_UNARY_OP_SILU, GGML_UNARY_OP_HARDSWISH, GGML_UNARY_OP_HARDSIGMOID, + GGML_UNARY_OP_EXP, GGML_UNARY_OP_COUNT, }; @@ -1165,6 +1167,14 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a); + GGML_API struct ggml_tensor * ggml_exp( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_exp_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + // normalize along rows GGML_API struct ggml_tensor * ggml_norm( struct ggml_context * ctx, @@ -1913,6 +1923,15 @@ extern "C" { struct ggml_tensor * pw, struct ggml_tensor * ph); + GGML_API struct ggml_tensor * ggml_rwkv_wkv( + struct ggml_context * ctx, + struct ggml_tensor * k, + struct ggml_tensor * v, + struct ggml_tensor * r, + struct ggml_tensor * tf, + struct ggml_tensor * td, + struct ggml_tensor * state); + // custom operators typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index dc6cdca0b..83140d757 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -2422,6 +2422,7 @@ inline static void ggml_vec_sigmoid_f32 (const int n, float * y, const float * x // TODO: optimize performance inline static void ggml_vec_hardswish_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); } inline static void ggml_vec_hardsigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); } +inline static void ggml_vec_exp_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = expf(x[i]); } static const float GELU_COEF_A = 0.044715f; static const float GELU_QUICK_COEF = -1.702f; @@ -2932,6 +2933,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "WIN_UNPART", "GET_REL_POS", "ADD_REL_POS", + "RWKV_WKV", "UNARY", @@ -2950,7 +2952,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "CROSS_ENTROPY_LOSS_BACK", }; -static_assert(GGML_OP_COUNT == 78, "GGML_OP_COUNT != 78"); +static_assert(GGML_OP_COUNT == 79, "GGML_OP_COUNT != 79"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -3024,6 +3026,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "win_unpart(x)", "get_rel_pos(x)", "add_rel_pos(x)", + "rwkv_wkv(k, v, r, tf, td, s)", "unary(x)", @@ -3042,7 +3045,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "cross_entropy_loss_back(x,y)", }; -static_assert(GGML_OP_COUNT == 78, "GGML_OP_COUNT != 78"); +static_assert(GGML_OP_COUNT == 79, "GGML_OP_COUNT != 79"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -3061,9 +3064,10 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = { "SILU", "HARDSWISH", "HARDSIGMOID", + "EXP", }; -static_assert(GGML_UNARY_OP_COUNT == 13, "GGML_UNARY_OP_COUNT != 13"); +static_assert(GGML_UNARY_OP_COUNT == 14, "GGML_UNARY_OP_COUNT != 14"); static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN"); @@ -5464,6 +5468,19 @@ struct ggml_tensor * ggml_hardsigmoid( return ggml_unary(ctx, a, GGML_UNARY_OP_HARDSIGMOID); } +// ggml exp +struct ggml_tensor * ggml_exp( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary(ctx, a, GGML_UNARY_OP_EXP); +} + +struct ggml_tensor * ggml_exp_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_EXP); +} + // ggml_norm static struct ggml_tensor * ggml_norm_impl( @@ -7727,6 +7744,59 @@ struct ggml_tensor * ggml_add_rel_pos_inplace( return ggml_add_rel_pos_impl(ctx, a, pw, ph, true); } +// ggml_rwkv_wkv + +struct ggml_tensor * ggml_rwkv_wkv( + struct ggml_context * ctx, + struct ggml_tensor * k, + struct ggml_tensor * v, + struct ggml_tensor * r, + struct ggml_tensor * tf, + struct ggml_tensor * td, + struct ggml_tensor * state) { + GGML_ASSERT(ggml_is_contiguous(k)); + GGML_ASSERT(ggml_is_contiguous(v)); + GGML_ASSERT(ggml_is_contiguous(r)); + GGML_ASSERT(ggml_is_contiguous(tf)); + GGML_ASSERT(ggml_is_contiguous(td)); + GGML_ASSERT(ggml_is_contiguous(state)); + + const int64_t S = k->ne[0]; + const int64_t H = k->ne[2]; + const int64_t n_tokens = k->ne[3]; + const int64_t n_seqs = state->ne[1]; + { + GGML_ASSERT(k->ne[1] == 1); + GGML_ASSERT(v->ne[0] == 1 && v->ne[1] == S && v->ne[2] == H && v->ne[3] == n_tokens); + GGML_ASSERT(r->ne[0] == 1 && r->ne[1] == S && r->ne[2] == H && r->ne[3] == n_tokens); + // TODO: RWKV v4 and v5 + GGML_ASSERT(td->ne[0] == 1 && td->ne[1] == S && td->ne[2] == H && td->ne[3] == n_tokens); + GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs); + } + + bool is_node = false; + + if (k->grad || v->grad || r->grad || tf->grad || td->grad || state->grad) { + GGML_ABORT("fatal error"); // TODO: implement backward + is_node = true; + } + + // concat output and new_state + const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); + + result->op = GGML_OP_RWKV_WKV; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = k; + result->src[1] = v; + result->src[2] = r; + result->src[3] = tf; + result->src[4] = td; + result->src[5] = state; + + return result; +} + // ggml_unary static struct ggml_tensor * ggml_unary_impl( @@ -12126,6 +12196,48 @@ static void ggml_compute_forward_hardsigmoid( } } +static void ggml_compute_forward_exp_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + assert(ggml_is_contiguous_1(src0)); + assert(ggml_is_contiguous_1(dst)); + assert(ggml_are_same_shape(src0, dst)); + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + for (int i = 0; i < n; i++) { + ggml_vec_exp_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void ggml_compute_forward_exp( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_exp_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + // ggml_compute_forward_norm @@ -16704,6 +16816,10 @@ static void ggml_compute_forward_unary( { ggml_compute_forward_hardsigmoid(params, dst); } break; + case GGML_UNARY_OP_EXP: + { + ggml_compute_forward_exp(params, dst); + } break; default: { GGML_ABORT("fatal error"); @@ -16839,6 +16955,96 @@ static void ggml_compute_forward_add_rel_pos( } } +// ggml_compute_forward_rwkv_wkv + +static void ggml_compute_forward_rwkv_wkv_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + const size_t T = dst->src[1]->ne[3]; + const size_t C = dst->ne[0]; + const size_t H = dst->src[1]->ne[2]; + const size_t n_seqs = dst->src[5]->ne[1]; + + float * dst_data = (float *) dst->data; + float * state = ((float *) dst->data) + C * T; + + if (params->ith != 0) { + return; + } + + memset(dst_data, 0, T * C * sizeof(float)); + + float * k = (float *) dst->src[0]->data; + float * v = (float *) dst->src[1]->data; + float * r = (float *) dst->src[2]->data; + float * time_faaaa = (float *) dst->src[3]->data; + float * time_decay = (float *) dst->src[4]->data; + + size_t t_stride = H * (C / H); + + size_t h_stride = C / H; + size_t h_stride_2d = (C / H) * (C / H); + + // basically fused operations: + // dst = r @ (time_faaaa * (k @ v) + state), + // state = time_decay * state + (k @ v), + // recursive through each token + for (size_t t = 0; t < T; t++) { + size_t t_offset = t * t_stride; + size_t state_offset = (C / H) * C * (t / (T / n_seqs)); + float * state_cur = state + state_offset; + float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset; + + for (size_t h = 0; h < H; h++) { + size_t h_offset = h * h_stride; + size_t t_h_offset = t_offset + h_offset; + size_t h_2d_offset = h * h_stride_2d; + + for (size_t i = 0; i < C / H; i++) { + size_t t_h_i_offset = t_h_offset + i; + size_t h_i_offset = h_offset + i; + size_t h_2d_i_offset = h_2d_offset + i * h_stride; + + float k_val = k[t_h_i_offset]; + float r_val = r[t_h_i_offset]; + float time_faaaa_val = time_faaaa[h_i_offset]; + // RWKV v6: different time_decay for each token. + float time_decay_val = time_decay[t_h_i_offset]; + + for (size_t j = 0; j < C / H; j ++) { + size_t t_h_j_offset = t_h_offset + j; + size_t h_2d_i_j_offset = h_2d_i_offset + j; + + float v_val = v[t_h_j_offset]; + float kv_val = v_val * k_val; + float prev_state_val = state_prev[h_2d_i_j_offset]; + float temp_val = kv_val * time_faaaa_val + prev_state_val; + dst_data[t_h_j_offset] += temp_val * r_val; + state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val; + } + } + } + } +} + +static void ggml_compute_forward_rwkv_wkv( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_rwkv_wkv_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + // ggml_compute_forward_map_unary static void ggml_compute_forward_map_unary_f32( @@ -17490,6 +17696,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_add_rel_pos(params, tensor); } break; + case GGML_OP_RWKV_WKV: + { + ggml_compute_forward_rwkv_wkv(params, tensor); + } break; case GGML_OP_MAP_UNARY: { ggml_unary_op_f32_t fun; @@ -18607,12 +18817,22 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor zero_table); } } break; + case GGML_UNARY_OP_EXP: + { + if (src0->grad) { + src0->grad = ggml_add_or_set(ctx, + src0->grad, + ggml_mul(ctx, tensor, tensor->grad), + zero_table); + } + } break; default: GGML_ABORT("fatal error"); } } break; case GGML_OP_GET_REL_POS: case GGML_OP_ADD_REL_POS: + case GGML_OP_RWKV_WKV: case GGML_OP_MAP_UNARY: case GGML_OP_MAP_BINARY: case GGML_OP_MAP_CUSTOM1_F32: @@ -19036,6 +19256,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_UNARY_OP_SIGMOID: case GGML_UNARY_OP_HARDSWISH: case GGML_UNARY_OP_HARDSIGMOID: + case GGML_UNARY_OP_EXP: { n_tasks = 1; } break; @@ -19127,6 +19348,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_OP_WIN_PART: case GGML_OP_WIN_UNPART: case GGML_OP_GET_REL_POS: + case GGML_OP_RWKV_WKV: case GGML_OP_MAP_UNARY: case GGML_OP_MAP_BINARY: case GGML_OP_MAP_CUSTOM1_F32: diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index b55effa99..a48c4fb67 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -94,6 +94,9 @@ class Keys: DECODER_START_TOKEN_ID = "{arch}.decoder_start_token_id" ATTN_LOGIT_SOFTCAPPING = "{arch}.attn_logit_softcapping" FINAL_LOGIT_SOFTCAPPING = "{arch}.final_logit_softcapping" + RESCALE_EVERY_N_LAYERS = "{arch}.rescale_every_n_layers" + TIME_MIX_EXTRA_DIM = "{arch}.time_mix_extra_dim" + TIME_DECAY_EXTRA_DIM = "{arch}.time_decay_extra_dim" class Attention: HEAD_COUNT = "{arch}.attention.head_count" @@ -132,6 +135,9 @@ class Keys: TIME_STEP_RANK = "{arch}.ssm.time_step_rank" DT_B_C_RMS = "{arch}.ssm.dt_b_c_rms" + class WKV: + HEAD_SIZE = "{arch}.wkv.head_size" + class Tokenizer: MODEL = "tokenizer.ggml.model" PRE = "tokenizer.ggml.pre" @@ -207,6 +213,7 @@ class MODEL_ARCH(IntEnum): GEMMA = auto() GEMMA2 = auto() STARCODER2 = auto() + RWKV6 = auto() MAMBA = auto() XVERSE = auto() COMMAND_R = auto() @@ -270,6 +277,29 @@ class MODEL_TENSOR(IntEnum): SSM_A = auto() SSM_D = auto() SSM_OUT = auto() + TIME_MIX_W1 = auto() + TIME_MIX_W2 = auto() + TIME_MIX_LERP_X = auto() + TIME_MIX_LERP_K = auto() + TIME_MIX_LERP_V = auto() + TIME_MIX_LERP_R = auto() + TIME_MIX_LERP_G = auto() + TIME_MIX_LERP_W = auto() + TIME_MIX_FIRST = auto() + TIME_MIX_DECAY = auto() + TIME_MIX_DECAY_W1 = auto() + TIME_MIX_DECAY_W2 = auto() + TIME_MIX_KEY = auto() + TIME_MIX_VALUE = auto() + TIME_MIX_RECEPTANCE = auto() + TIME_MIX_GATE = auto() + TIME_MIX_LN = auto() + TIME_MIX_OUTPUT = auto() + CHANNEL_MIX_LERP_K = auto() + CHANNEL_MIX_LERP_R = auto() + CHANNEL_MIX_KEY = auto() + CHANNEL_MIX_RECEPTANCE = auto() + CHANNEL_MIX_VALUE = auto() ATTN_Q_A = auto() ATTN_Q_B = auto() ATTN_KV_A_MQA = auto() @@ -337,6 +367,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.GEMMA: "gemma", MODEL_ARCH.GEMMA2: "gemma2", MODEL_ARCH.STARCODER2: "starcoder2", + MODEL_ARCH.RWKV6: "rwkv6", MODEL_ARCH.MAMBA: "mamba", MODEL_ARCH.XVERSE: "xverse", MODEL_ARCH.COMMAND_R: "command-r", @@ -355,87 +386,110 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { } TENSOR_NAMES: dict[MODEL_TENSOR, str] = { - MODEL_TENSOR.TOKEN_EMBD: "token_embd", - MODEL_TENSOR.TOKEN_EMBD_NORM: "token_embd_norm", - MODEL_TENSOR.TOKEN_TYPES: "token_types", - MODEL_TENSOR.POS_EMBD: "position_embd", - MODEL_TENSOR.OUTPUT_NORM: "output_norm", - MODEL_TENSOR.OUTPUT: "output", - MODEL_TENSOR.ROPE_FREQS: "rope_freqs", - MODEL_TENSOR.ROPE_FACTORS_LONG: "rope_factors_long", - MODEL_TENSOR.ROPE_FACTORS_SHORT: "rope_factors_short", - MODEL_TENSOR.ATTN_NORM: "blk.{bid}.attn_norm", - MODEL_TENSOR.ATTN_NORM_2: "blk.{bid}.attn_norm_2", - MODEL_TENSOR.ATTN_QKV: "blk.{bid}.attn_qkv", - MODEL_TENSOR.ATTN_Q: "blk.{bid}.attn_q", - MODEL_TENSOR.ATTN_K: "blk.{bid}.attn_k", - MODEL_TENSOR.ATTN_V: "blk.{bid}.attn_v", - MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output", - MODEL_TENSOR.ATTN_ROT_EMBD: "blk.{bid}.attn_rot_embd", - MODEL_TENSOR.ATTN_Q_NORM: "blk.{bid}.attn_q_norm", - MODEL_TENSOR.ATTN_K_NORM: "blk.{bid}.attn_k_norm", - MODEL_TENSOR.ATTN_OUT_NORM: "blk.{bid}.attn_output_norm", - MODEL_TENSOR.ATTN_POST_NORM: "blk.{bid}.post_attention_norm", - MODEL_TENSOR.FFN_GATE_INP: "blk.{bid}.ffn_gate_inp", - MODEL_TENSOR.FFN_GATE_INP_SHEXP: "blk.{bid}.ffn_gate_inp_shexp", - MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm", - MODEL_TENSOR.FFN_PRE_NORM: "blk.{bid}.ffn_norm", - MODEL_TENSOR.FFN_POST_NORM: "blk.{bid}.post_ffw_norm", - MODEL_TENSOR.FFN_GATE: "blk.{bid}.ffn_gate", - MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down", - MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up", - MODEL_TENSOR.FFN_GATE_SHEXP: "blk.{bid}.ffn_gate_shexp", - MODEL_TENSOR.FFN_DOWN_SHEXP: "blk.{bid}.ffn_down_shexp", - MODEL_TENSOR.FFN_UP_SHEXP: "blk.{bid}.ffn_up_shexp", - MODEL_TENSOR.FFN_ACT: "blk.{bid}.ffn", - MODEL_TENSOR.FFN_NORM_EXP: "blk.{bid}.ffn_norm_exps", - MODEL_TENSOR.FFN_GATE_EXP: "blk.{bid}.ffn_gate_exps", - MODEL_TENSOR.FFN_DOWN_EXP: "blk.{bid}.ffn_down_exps", - MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up_exps", - MODEL_TENSOR.LAYER_OUT_NORM: "blk.{bid}.layer_output_norm", - MODEL_TENSOR.SSM_IN: "blk.{bid}.ssm_in", - MODEL_TENSOR.SSM_CONV1D: "blk.{bid}.ssm_conv1d", - MODEL_TENSOR.SSM_X: "blk.{bid}.ssm_x", - MODEL_TENSOR.SSM_DT: "blk.{bid}.ssm_dt", - MODEL_TENSOR.SSM_A: "blk.{bid}.ssm_a", - MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d", - MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out", - MODEL_TENSOR.ATTN_Q_A: "blk.{bid}.attn_q_a", - MODEL_TENSOR.ATTN_Q_B: "blk.{bid}.attn_q_b", - MODEL_TENSOR.ATTN_KV_A_MQA: "blk.{bid}.attn_kv_a_mqa", - MODEL_TENSOR.ATTN_KV_B: "blk.{bid}.attn_kv_b", - MODEL_TENSOR.ATTN_Q_A_NORM: "blk.{bid}.attn_q_a_norm", - MODEL_TENSOR.ATTN_KV_A_NORM: "blk.{bid}.attn_kv_a_norm", - MODEL_TENSOR.ATTN_SUB_NORM: "blk.{bid}.attn_sub_norm", - MODEL_TENSOR.FFN_SUB_NORM: "blk.{bid}.ffn_sub_norm", - MODEL_TENSOR.DEC_ATTN_NORM: "dec.blk.{bid}.attn_norm", - MODEL_TENSOR.DEC_ATTN_Q: "dec.blk.{bid}.attn_q", - MODEL_TENSOR.DEC_ATTN_K: "dec.blk.{bid}.attn_k", - MODEL_TENSOR.DEC_ATTN_V: "dec.blk.{bid}.attn_v", - MODEL_TENSOR.DEC_ATTN_OUT: "dec.blk.{bid}.attn_o", - MODEL_TENSOR.DEC_ATTN_REL_B: "dec.blk.{bid}.attn_rel_b", - MODEL_TENSOR.DEC_CROSS_ATTN_NORM: "dec.blk.{bid}.cross_attn_norm", - MODEL_TENSOR.DEC_CROSS_ATTN_Q: "dec.blk.{bid}.cross_attn_q", - MODEL_TENSOR.DEC_CROSS_ATTN_K: "dec.blk.{bid}.cross_attn_k", - MODEL_TENSOR.DEC_CROSS_ATTN_V: "dec.blk.{bid}.cross_attn_v", - MODEL_TENSOR.DEC_CROSS_ATTN_OUT: "dec.blk.{bid}.cross_attn_o", - MODEL_TENSOR.DEC_CROSS_ATTN_REL_B: "dec.blk.{bid}.cross_attn_rel_b", - MODEL_TENSOR.DEC_FFN_NORM: "dec.blk.{bid}.ffn_norm", - MODEL_TENSOR.DEC_FFN_GATE: "dec.blk.{bid}.ffn_gate", - MODEL_TENSOR.DEC_FFN_DOWN: "dec.blk.{bid}.ffn_down", - MODEL_TENSOR.DEC_FFN_UP: "dec.blk.{bid}.ffn_up", - MODEL_TENSOR.DEC_OUTPUT_NORM: "dec.output_norm", - MODEL_TENSOR.ENC_ATTN_NORM: "enc.blk.{bid}.attn_norm", - MODEL_TENSOR.ENC_ATTN_Q: "enc.blk.{bid}.attn_q", - MODEL_TENSOR.ENC_ATTN_K: "enc.blk.{bid}.attn_k", - MODEL_TENSOR.ENC_ATTN_V: "enc.blk.{bid}.attn_v", - MODEL_TENSOR.ENC_ATTN_OUT: "enc.blk.{bid}.attn_o", - MODEL_TENSOR.ENC_ATTN_REL_B: "enc.blk.{bid}.attn_rel_b", - MODEL_TENSOR.ENC_FFN_NORM: "enc.blk.{bid}.ffn_norm", - MODEL_TENSOR.ENC_FFN_GATE: "enc.blk.{bid}.ffn_gate", - MODEL_TENSOR.ENC_FFN_DOWN: "enc.blk.{bid}.ffn_down", - MODEL_TENSOR.ENC_FFN_UP: "enc.blk.{bid}.ffn_up", - MODEL_TENSOR.ENC_OUTPUT_NORM: "enc.output_norm", + MODEL_TENSOR.TOKEN_EMBD: "token_embd", + MODEL_TENSOR.TOKEN_EMBD_NORM: "token_embd_norm", + MODEL_TENSOR.TOKEN_TYPES: "token_types", + MODEL_TENSOR.POS_EMBD: "position_embd", + MODEL_TENSOR.OUTPUT_NORM: "output_norm", + MODEL_TENSOR.OUTPUT: "output", + MODEL_TENSOR.ROPE_FREQS: "rope_freqs", + MODEL_TENSOR.ROPE_FACTORS_LONG: "rope_factors_long", + MODEL_TENSOR.ROPE_FACTORS_SHORT: "rope_factors_short", + MODEL_TENSOR.ATTN_NORM: "blk.{bid}.attn_norm", + MODEL_TENSOR.ATTN_NORM_2: "blk.{bid}.attn_norm_2", + MODEL_TENSOR.ATTN_QKV: "blk.{bid}.attn_qkv", + MODEL_TENSOR.ATTN_Q: "blk.{bid}.attn_q", + MODEL_TENSOR.ATTN_K: "blk.{bid}.attn_k", + MODEL_TENSOR.ATTN_V: "blk.{bid}.attn_v", + MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output", + MODEL_TENSOR.ATTN_ROT_EMBD: "blk.{bid}.attn_rot_embd", + MODEL_TENSOR.ATTN_Q_NORM: "blk.{bid}.attn_q_norm", + MODEL_TENSOR.ATTN_K_NORM: "blk.{bid}.attn_k_norm", + MODEL_TENSOR.ATTN_OUT_NORM: "blk.{bid}.attn_output_norm", + MODEL_TENSOR.ATTN_POST_NORM: "blk.{bid}.post_attention_norm", + MODEL_TENSOR.FFN_GATE_INP: "blk.{bid}.ffn_gate_inp", + MODEL_TENSOR.FFN_GATE_INP_SHEXP: "blk.{bid}.ffn_gate_inp_shexp", + MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm", + MODEL_TENSOR.FFN_PRE_NORM: "blk.{bid}.ffn_norm", + MODEL_TENSOR.FFN_POST_NORM: "blk.{bid}.post_ffw_norm", + MODEL_TENSOR.FFN_GATE: "blk.{bid}.ffn_gate", + MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down", + MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up", + MODEL_TENSOR.FFN_GATE_SHEXP: "blk.{bid}.ffn_gate_shexp", + MODEL_TENSOR.FFN_DOWN_SHEXP: "blk.{bid}.ffn_down_shexp", + MODEL_TENSOR.FFN_UP_SHEXP: "blk.{bid}.ffn_up_shexp", + MODEL_TENSOR.FFN_ACT: "blk.{bid}.ffn", + MODEL_TENSOR.FFN_NORM_EXP: "blk.{bid}.ffn_norm_exps", + MODEL_TENSOR.FFN_GATE_EXP: "blk.{bid}.ffn_gate_exps", + MODEL_TENSOR.FFN_DOWN_EXP: "blk.{bid}.ffn_down_exps", + MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up_exps", + MODEL_TENSOR.LAYER_OUT_NORM: "blk.{bid}.layer_output_norm", + MODEL_TENSOR.SSM_IN: "blk.{bid}.ssm_in", + MODEL_TENSOR.SSM_CONV1D: "blk.{bid}.ssm_conv1d", + MODEL_TENSOR.SSM_X: "blk.{bid}.ssm_x", + MODEL_TENSOR.SSM_DT: "blk.{bid}.ssm_dt", + MODEL_TENSOR.SSM_A: "blk.{bid}.ssm_a", + MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d", + MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out", + MODEL_TENSOR.TIME_MIX_W1: "blk.{bid}.time_mix_w1", + MODEL_TENSOR.TIME_MIX_W2: "blk.{bid}.time_mix_w2", + MODEL_TENSOR.TIME_MIX_LERP_X: "blk.{bid}.time_mix_lerp_x", + MODEL_TENSOR.TIME_MIX_LERP_K: "blk.{bid}.time_mix_lerp_k", + MODEL_TENSOR.TIME_MIX_LERP_V: "blk.{bid}.time_mix_lerp_v", + MODEL_TENSOR.TIME_MIX_LERP_R: "blk.{bid}.time_mix_lerp_r", + MODEL_TENSOR.TIME_MIX_LERP_G: "blk.{bid}.time_mix_lerp_g", + MODEL_TENSOR.TIME_MIX_LERP_W: "blk.{bid}.time_mix_lerp_w", + MODEL_TENSOR.TIME_MIX_FIRST: "blk.{bid}.time_mix_first", + MODEL_TENSOR.TIME_MIX_DECAY: "blk.{bid}.time_mix_decay", + MODEL_TENSOR.TIME_MIX_DECAY_W1: "blk.{bid}.time_mix_decay_w1", + MODEL_TENSOR.TIME_MIX_DECAY_W2: "blk.{bid}.time_mix_decay_w2", + MODEL_TENSOR.TIME_MIX_KEY: "blk.{bid}.time_mix_key", + MODEL_TENSOR.TIME_MIX_VALUE: "blk.{bid}.time_mix_value", + MODEL_TENSOR.TIME_MIX_RECEPTANCE: "blk.{bid}.time_mix_receptance", + MODEL_TENSOR.TIME_MIX_GATE: "blk.{bid}.time_mix_gate", + MODEL_TENSOR.TIME_MIX_LN: "blk.{bid}.time_mix_ln", + MODEL_TENSOR.TIME_MIX_OUTPUT: "blk.{bid}.time_mix_output", + MODEL_TENSOR.CHANNEL_MIX_LERP_K: "blk.{bid}.channel_mix_lerp_k", + MODEL_TENSOR.CHANNEL_MIX_LERP_R: "blk.{bid}.channel_mix_lerp_r", + MODEL_TENSOR.CHANNEL_MIX_KEY: "blk.{bid}.channel_mix_key", + MODEL_TENSOR.CHANNEL_MIX_RECEPTANCE: "blk.{bid}.channel_mix_receptance", + MODEL_TENSOR.CHANNEL_MIX_VALUE: "blk.{bid}.channel_mix_value", + MODEL_TENSOR.ATTN_Q_A: "blk.{bid}.attn_q_a", + MODEL_TENSOR.ATTN_Q_B: "blk.{bid}.attn_q_b", + MODEL_TENSOR.ATTN_KV_A_MQA: "blk.{bid}.attn_kv_a_mqa", + MODEL_TENSOR.ATTN_KV_B: "blk.{bid}.attn_kv_b", + MODEL_TENSOR.ATTN_Q_A_NORM: "blk.{bid}.attn_q_a_norm", + MODEL_TENSOR.ATTN_KV_A_NORM: "blk.{bid}.attn_kv_a_norm", + MODEL_TENSOR.ATTN_SUB_NORM: "blk.{bid}.attn_sub_norm", + MODEL_TENSOR.FFN_SUB_NORM: "blk.{bid}.ffn_sub_norm", + MODEL_TENSOR.DEC_ATTN_NORM: "dec.blk.{bid}.attn_norm", + MODEL_TENSOR.DEC_ATTN_Q: "dec.blk.{bid}.attn_q", + MODEL_TENSOR.DEC_ATTN_K: "dec.blk.{bid}.attn_k", + MODEL_TENSOR.DEC_ATTN_V: "dec.blk.{bid}.attn_v", + MODEL_TENSOR.DEC_ATTN_OUT: "dec.blk.{bid}.attn_o", + MODEL_TENSOR.DEC_ATTN_REL_B: "dec.blk.{bid}.attn_rel_b", + MODEL_TENSOR.DEC_CROSS_ATTN_NORM: "dec.blk.{bid}.cross_attn_norm", + MODEL_TENSOR.DEC_CROSS_ATTN_Q: "dec.blk.{bid}.cross_attn_q", + MODEL_TENSOR.DEC_CROSS_ATTN_K: "dec.blk.{bid}.cross_attn_k", + MODEL_TENSOR.DEC_CROSS_ATTN_V: "dec.blk.{bid}.cross_attn_v", + MODEL_TENSOR.DEC_CROSS_ATTN_OUT: "dec.blk.{bid}.cross_attn_o", + MODEL_TENSOR.DEC_CROSS_ATTN_REL_B: "dec.blk.{bid}.cross_attn_rel_b", + MODEL_TENSOR.DEC_FFN_NORM: "dec.blk.{bid}.ffn_norm", + MODEL_TENSOR.DEC_FFN_GATE: "dec.blk.{bid}.ffn_gate", + MODEL_TENSOR.DEC_FFN_DOWN: "dec.blk.{bid}.ffn_down", + MODEL_TENSOR.DEC_FFN_UP: "dec.blk.{bid}.ffn_up", + MODEL_TENSOR.DEC_OUTPUT_NORM: "dec.output_norm", + MODEL_TENSOR.ENC_ATTN_NORM: "enc.blk.{bid}.attn_norm", + MODEL_TENSOR.ENC_ATTN_Q: "enc.blk.{bid}.attn_q", + MODEL_TENSOR.ENC_ATTN_K: "enc.blk.{bid}.attn_k", + MODEL_TENSOR.ENC_ATTN_V: "enc.blk.{bid}.attn_v", + MODEL_TENSOR.ENC_ATTN_OUT: "enc.blk.{bid}.attn_o", + MODEL_TENSOR.ENC_ATTN_REL_B: "enc.blk.{bid}.attn_rel_b", + MODEL_TENSOR.ENC_FFN_NORM: "enc.blk.{bid}.ffn_norm", + MODEL_TENSOR.ENC_FFN_GATE: "enc.blk.{bid}.ffn_gate", + MODEL_TENSOR.ENC_FFN_DOWN: "enc.blk.{bid}.ffn_down", + MODEL_TENSOR.ENC_FFN_UP: "enc.blk.{bid}.ffn_up", + MODEL_TENSOR.ENC_OUTPUT_NORM: "enc.output_norm", } MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { @@ -856,6 +910,37 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_UP, ], + MODEL_ARCH.RWKV6: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.TOKEN_EMBD_NORM, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_NORM_2, + MODEL_TENSOR.TIME_MIX_W1, + MODEL_TENSOR.TIME_MIX_W2, + MODEL_TENSOR.TIME_MIX_LERP_X, + MODEL_TENSOR.TIME_MIX_LERP_K, + MODEL_TENSOR.TIME_MIX_LERP_V, + MODEL_TENSOR.TIME_MIX_LERP_R, + MODEL_TENSOR.TIME_MIX_LERP_G, + MODEL_TENSOR.TIME_MIX_LERP_W, + MODEL_TENSOR.TIME_MIX_FIRST, + MODEL_TENSOR.TIME_MIX_DECAY, + MODEL_TENSOR.TIME_MIX_DECAY_W1, + MODEL_TENSOR.TIME_MIX_DECAY_W2, + MODEL_TENSOR.TIME_MIX_KEY, + MODEL_TENSOR.TIME_MIX_VALUE, + MODEL_TENSOR.TIME_MIX_RECEPTANCE, + MODEL_TENSOR.TIME_MIX_GATE, + MODEL_TENSOR.TIME_MIX_LN, + MODEL_TENSOR.TIME_MIX_OUTPUT, + MODEL_TENSOR.CHANNEL_MIX_LERP_K, + MODEL_TENSOR.CHANNEL_MIX_LERP_R, + MODEL_TENSOR.CHANNEL_MIX_KEY, + MODEL_TENSOR.CHANNEL_MIX_RECEPTANCE, + MODEL_TENSOR.CHANNEL_MIX_VALUE, + ], MODEL_ARCH.MAMBA: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index af3b98c67..3c95c2673 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -670,6 +670,18 @@ class GGUFWriter: def add_expert_weights_scale(self, value: float) -> None: self.add_float32(Keys.LLM.EXPERT_WEIGHTS_SCALE.format(arch=self.arch), value) + def add_rescale_every_n_layers(self, count: int) -> None: + self.add_uint32(Keys.LLM.RESCALE_EVERY_N_LAYERS.format(arch=self.arch), count) + + def add_time_mix_extra_dim(self, dim: int) -> None: + self.add_uint32(Keys.LLM.TIME_MIX_EXTRA_DIM.format(arch=self.arch), dim) + + def add_time_decay_extra_dim(self, dim: int) -> None: + self.add_uint32(Keys.LLM.TIME_DECAY_EXTRA_DIM.format(arch=self.arch), dim) + + def add_wkv_head_size(self, size: int) -> None: + self.add_uint32(Keys.WKV.HEAD_SIZE.format(arch=self.arch), size) + def add_layer_norm_eps(self, value: float) -> None: self.add_float32(Keys.Attention.LAYERNORM_EPS.format(arch=self.arch), value) diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index a4f185c06..bc9a13ee5 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -27,6 +27,7 @@ class TensorNameMap: "embedding.word_embeddings", # chatglm "transformer.token_embeddings", # openelm "shared", # t5 + "rwkv.embeddings", # rwkv ), # Token type embeddings @@ -40,6 +41,7 @@ class TensorNameMap: "embeddings.LayerNorm", # bert "emb_ln", # nomic-bert "transformer.norm", # openelm + "rwkv.blocks.0.pre_ln", # rwkv ), # Position embeddings @@ -57,6 +59,7 @@ class TensorNameMap: "word_embeddings_for_head", # persimmon "lm_head.linear", # phi2 "output_layer", # chatglm + "head", # rwkv ), # Output norm @@ -76,6 +79,7 @@ class TensorNameMap: "encoder.final_layernorm", # chatglm "transformer.norm", # openelm "model.norm", # nemotron + "rwkv.ln_out", # rwkv ), # Rope frequencies @@ -108,12 +112,14 @@ class TensorNameMap: "transformer.blocks.{bid}.norm_attn_norm.norm_1", # dbrx "encoder.layers.{bid}.input_layernorm", # chatglm "transformer.layers.{bid}.attn_norm", # openelm + "rwkv.blocks.{bid}.ln1", # rwkv ), # Attention norm 2 MODEL_TENSOR.ATTN_NORM_2: ( - "transformer.h.{bid}.ln_attn", # falcon40b + "transformer.h.{bid}.ln_attn", # falcon40b "encoder.layer.{bid}.layer_norm_1", # jina-v2-code + "rwkv.blocks.{bid}.ln2", # rwkv ), # Attention query-key-value @@ -434,6 +440,98 @@ class TensorNameMap: "backbone.layers.{bid}.mixer.out_proj", ), + MODEL_TENSOR.TIME_MIX_W1: ( + "rwkv.blocks.{bid}.attention.time_maa_w1", # rwkv v6 + ), + + MODEL_TENSOR.TIME_MIX_W2: ( + "rwkv.blocks.{bid}.attention.time_maa_w2", # rwkv v6 + ), + + MODEL_TENSOR.TIME_MIX_LERP_X: ( + "rwkv.blocks.{bid}.attention.time_maa_x", # rwkv v6 + ), + + MODEL_TENSOR.TIME_MIX_LERP_K: ( + "rwkv.blocks.{bid}.attention.time_maa_k", # rwkv v6 + ), + + MODEL_TENSOR.TIME_MIX_LERP_V: ( + "rwkv.blocks.{bid}.attention.time_maa_v", # rwkv v6 + ), + + MODEL_TENSOR.TIME_MIX_LERP_R: ( + "rwkv.blocks.{bid}.attention.time_maa_r", # rwkv v6 + ), + + MODEL_TENSOR.TIME_MIX_LERP_G: ( + "rwkv.blocks.{bid}.attention.time_maa_g", # rwkv v6 + ), + + MODEL_TENSOR.TIME_MIX_LERP_W: ( + "rwkv.blocks.{bid}.attention.time_maa_w", # rwkv v6 + ), + + MODEL_TENSOR.TIME_MIX_FIRST: ( + "rwkv.blocks.{bid}.attention.time_faaaa", # rwkv v6 + ), + + MODEL_TENSOR.TIME_MIX_DECAY: ( + "rwkv.blocks.{bid}.attention.time_decay", # rwkv v6 + ), + + MODEL_TENSOR.TIME_MIX_DECAY_W1: ( + "rwkv.blocks.{bid}.attention.time_decay_w1", # rwkv v6 + ), + + MODEL_TENSOR.TIME_MIX_DECAY_W2: ( + "rwkv.blocks.{bid}.attention.time_decay_w2", # rwkv v6 + ), + + MODEL_TENSOR.TIME_MIX_KEY: ( + "rwkv.blocks.{bid}.attention.key", # rwkv + ), + + MODEL_TENSOR.TIME_MIX_VALUE: ( + "rwkv.blocks.{bid}.attention.value", # rwkv + ), + + MODEL_TENSOR.TIME_MIX_RECEPTANCE: ( + "rwkv.blocks.{bid}.attention.receptance", # rwkv + ), + + MODEL_TENSOR.TIME_MIX_GATE: ( + "rwkv.blocks.{bid}.attention.gate", # rwkv + ), + + MODEL_TENSOR.TIME_MIX_LN: ( + "rwkv.blocks.{bid}.attention.ln_x", # rwkv + ), + + MODEL_TENSOR.TIME_MIX_OUTPUT: ( + "rwkv.blocks.{bid}.attention.output", # rwkv + ), + + MODEL_TENSOR.CHANNEL_MIX_LERP_K: ( + "rwkv.blocks.{bid}.feed_forward.time_maa_k", # rwkv v6 + ), + + MODEL_TENSOR.CHANNEL_MIX_LERP_R: ( + "rwkv.blocks.{bid}.feed_forward.time_maa_r", # rwkv v6 + ), + + MODEL_TENSOR.CHANNEL_MIX_KEY: ( + "rwkv.blocks.{bid}.feed_forward.key", # rwkv + ), + + MODEL_TENSOR.CHANNEL_MIX_RECEPTANCE: ( + "rwkv.blocks.{bid}.feed_forward.receptance", # rwkv + ), + + MODEL_TENSOR.CHANNEL_MIX_VALUE: ( + "rwkv.blocks.{bid}.feed_forward.value", # rwkv + ), + MODEL_TENSOR.ATTN_Q_A: ( "model.layers.{bid}.self_attn.q_a_proj", # deepseek2 ), diff --git a/include/llama.h b/include/llama.h index f2e701602..bfc37e88b 100644 --- a/include/llama.h +++ b/include/llama.h @@ -66,6 +66,7 @@ extern "C" { LLAMA_VOCAB_TYPE_BPE = 2, // GPT-2 tokenizer based on byte-level BPE LLAMA_VOCAB_TYPE_WPM = 3, // BERT tokenizer based on WordPiece LLAMA_VOCAB_TYPE_UGM = 4, // T5 tokenizer based on Unigram + LLAMA_VOCAB_TYPE_RWKV = 5, // RWKV tokenizer based on greedy tokenization }; // pre-tokenization types diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index dc4138cbc..2c007477e 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -58,17 +58,17 @@ struct naive_trie { auto res = children.find(c); if (res != children.end()) { return res->second.get_longest_prefix(key, len, offset + 1); - } else { - return std::make_pair(key, offset); } + + return std::make_pair(key, offset); } - struct naive_trie * traverse(const char c) { + const struct naive_trie * traverse(const char c) const { auto res = children.find(c); if (res != children.end()) { return &res->second; - } else { - return NULL; } + + return NULL; } std::map children; bool has_value; @@ -843,7 +843,7 @@ struct llm_tokenizer_ugm { // traverse the token matcher trie to find a matching token bool single_codepoint_token_found = false; const struct best_tokenization & current_best = tokenization_results[input_offset]; - struct naive_trie * node = token_matcher.traverse(normalized[prefix_offset++]); + const struct naive_trie * node = token_matcher.traverse(normalized[prefix_offset++]); while (prefix_offset <= input_len && node != NULL) { // check if we found valid token in prefix @@ -1097,6 +1097,111 @@ private: struct naive_trie token_matcher; }; +// +// RWKV tokenizer +// + +static std::vector llama_unescape_rwkv_token(const std::string & escaped) { + std::vector output; + output.reserve(escaped.size()); + + // Parser state + bool escaping = false; + uint8_t hex_remaining = 0; + uint8_t hex_acc = 0; + + // Step through characters, performing parsing + for (const char & c : escaped) { + // If we're parsing a hex code, interpret the next character + if (hex_remaining != 0) { + uint8_t value = (c >= 'a') ? (c - 'a' + 10) : (c - '0'); + hex_acc = (hex_acc << 4) + value; + + hex_remaining -= 1; + if (hex_remaining == 0) { + output.push_back(hex_acc); + hex_acc = 0; + } + + continue; + } + + // If we got an escape character, interpret it + if (escaping) { + if (c == 't') { + output.push_back('\t'); + } else if (c == 'n') { + output.push_back('\n'); + } else if (c == 'r') { + output.push_back('\r'); + } else if (c == 'x') { + hex_remaining = 2; + } else { + output.push_back(c); + } + + escaping = false; + continue; + } + + if (c == '\\') { + escaping = true; + continue; + } + + output.push_back(c); + } + + return output; +} + +struct llm_tokenizer_rwkv { + llm_tokenizer_rwkv(const llama_vocab & vocab): vocab(vocab) { + // RWKV supports arbitrary byte tokens, but the vocab struct only supports string tokens. + // For now, we decode the vocab here into the lookup we'll use for tokenization. + + // build trie + for (unsigned int id = 0; id < vocab.id_to_token.size(); ++id) { + const auto & token = vocab.id_to_token[id]; + const auto data = llama_unescape_rwkv_token(token.text); + token_matcher.insert((const char *) data.data(), data.size(), id); + } + } + + void tokenize(const std::string & text, std::vector & output) { + uint32_t position = 0; + + while (position < text.size()) { + const struct naive_trie * node = token_matcher.traverse(text[position]); + if (node == NULL) { + // no matching token found, add unknown token + output.push_back(vocab.special_unk_id); + position += 1; + continue; + } + + // traverse the trie to find the longest matching token + uint32_t token_id = 0; + uint32_t token_length = 0; + while (node != NULL) { + if (node->has_value) { + token_id = node->value; + token_length = position + 1; + } + node = node->traverse(text[++position]); + } + + // add the longest matching token + output.push_back(token_id); + position = token_length; + } + } + + const llama_vocab & vocab; + + struct naive_trie token_matcher; +}; + // // (de-) tokenize // @@ -1401,6 +1506,23 @@ std::vector llama_tokenize_internal(const llama_vocab & vocab, output.push_back(vocab.special_eos_id); } } break; + case LLAMA_VOCAB_TYPE_RWKV: + { + for (const auto & fragment : fragment_buffer) { + if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { + auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length); + +#ifdef PRETOKENIZERDEBUG + LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str()); +#endif + + llm_tokenizer_rwkv tokenizer(vocab); + tokenizer.tokenize(raw_text, output); + } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN) + output.push_back(fragment.token); + } + } + } break; case LLAMA_VOCAB_TYPE_NONE: GGML_ABORT("fatal error"); } @@ -1616,6 +1738,17 @@ int32_t llama_token_to_piece_impl(const struct llama_vocab & vocab, llama_token } break; } + case LLAMA_VOCAB_TYPE_RWKV: { + std::vector result = llama_unescape_rwkv_token(token_text); + + // If we don't have enough space, return an error + if (result.size() > (size_t)length) { + return -(int)result.size(); + } + + memcpy(buf, result.data(), result.size()); + return (int)result.size(); + } default: GGML_ABORT("fatal error"); } diff --git a/src/llama.cpp b/src/llama.cpp index 2274296b4..2113c72f3 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -212,6 +212,7 @@ enum llm_arch { LLM_ARCH_JAIS, LLM_ARCH_NEMOTRON, LLM_ARCH_EXAONE, + LLM_ARCH_RWKV6, LLM_ARCH_UNKNOWN, }; @@ -259,6 +260,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_JAIS, "jais" }, { LLM_ARCH_NEMOTRON, "nemotron" }, { LLM_ARCH_EXAONE, "exaone" }, + { LLM_ARCH_RWKV6, "rwkv6" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -295,6 +297,9 @@ enum llm_kv { LLM_KV_DECODER_START_TOKEN_ID, LLM_KV_ATTN_LOGIT_SOFTCAPPING, LLM_KV_FINAL_LOGIT_SOFTCAPPING, + LLM_KV_RESCALE_EVERY_N_LAYERS, + LLM_KV_TIME_MIX_EXTRA_DIM, + LLM_KV_TIME_DECAY_EXTRA_DIM, LLM_KV_ATTENTION_HEAD_COUNT, LLM_KV_ATTENTION_HEAD_COUNT_KV, @@ -330,6 +335,8 @@ enum llm_kv { LLM_KV_SSM_TIME_STEP_RANK, LLM_KV_SSM_DT_B_C_RMS, + LLM_KV_WKV_HEAD_SIZE, + LLM_KV_TOKENIZER_MODEL, LLM_KV_TOKENIZER_PRE, LLM_KV_TOKENIZER_LIST, @@ -389,11 +396,14 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_EXPERT_USED_COUNT, "%s.expert_used_count" }, { LLM_KV_EXPERT_SHARED_COUNT, "%s.expert_shared_count" }, { LLM_KV_EXPERT_WEIGHTS_SCALE, "%s.expert_weights_scale" }, - { LLM_KV_POOLING_TYPE , "%s.pooling_type" }, + { LLM_KV_POOLING_TYPE, "%s.pooling_type" }, { LLM_KV_LOGIT_SCALE, "%s.logit_scale" }, { LLM_KV_DECODER_START_TOKEN_ID, "%s.decoder_start_token_id" }, { LLM_KV_ATTN_LOGIT_SOFTCAPPING, "%s.attn_logit_softcapping" }, { LLM_KV_FINAL_LOGIT_SOFTCAPPING, "%s.final_logit_softcapping" }, + { LLM_KV_RESCALE_EVERY_N_LAYERS, "%s.rescale_every_n_layers" }, + { LLM_KV_TIME_MIX_EXTRA_DIM, "%s.time_mix_extra_dim" }, + { LLM_KV_TIME_DECAY_EXTRA_DIM, "%s.time_decay_extra_dim" }, { LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" }, { LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" }, @@ -429,6 +439,8 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_SSM_TIME_STEP_RANK, "%s.ssm.time_step_rank" }, { LLM_KV_SSM_DT_B_C_RMS, "%s.ssm.dt_b_c_rms" }, + { LLM_KV_WKV_HEAD_SIZE, "%s.wkv.head_size" }, + { LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" }, { LLM_KV_TOKENIZER_PRE, "tokenizer.ggml.pre" }, { LLM_KV_TOKENIZER_LIST, "tokenizer.ggml.tokens" }, @@ -518,6 +530,29 @@ enum llm_tensor { LLM_TENSOR_SSM_A, LLM_TENSOR_SSM_D, LLM_TENSOR_SSM_OUT, + LLM_TENSOR_TIME_MIX_W1, + LLM_TENSOR_TIME_MIX_W2, + LLM_TENSOR_TIME_MIX_LERP_X, + LLM_TENSOR_TIME_MIX_LERP_W, + LLM_TENSOR_TIME_MIX_LERP_K, + LLM_TENSOR_TIME_MIX_LERP_V, + LLM_TENSOR_TIME_MIX_LERP_R, + LLM_TENSOR_TIME_MIX_LERP_G, + LLM_TENSOR_TIME_MIX_FIRST, + LLM_TENSOR_TIME_MIX_DECAY, + LLM_TENSOR_TIME_MIX_DECAY_W1, + LLM_TENSOR_TIME_MIX_DECAY_W2, + LLM_TENSOR_TIME_MIX_KEY, + LLM_TENSOR_TIME_MIX_VALUE, + LLM_TENSOR_TIME_MIX_RECEPTANCE, + LLM_TENSOR_TIME_MIX_GATE, + LLM_TENSOR_TIME_MIX_LN, + LLM_TENSOR_TIME_MIX_OUTPUT, + LLM_TENSOR_CHANNEL_MIX_LERP_K, + LLM_TENSOR_CHANNEL_MIX_LERP_R, + LLM_TENSOR_CHANNEL_MIX_KEY, + LLM_TENSOR_CHANNEL_MIX_RECEPTANCE, + LLM_TENSOR_CHANNEL_MIX_VALUE, LLM_TENSOR_ATTN_Q_A, LLM_TENSOR_ATTN_Q_B, LLM_TENSOR_ATTN_KV_A_MQA, @@ -1339,6 +1374,40 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, + { + LLM_ARCH_RWKV6, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_NORM_2, "blk.%d.attn_norm_2" }, + { LLM_TENSOR_TIME_MIX_W1, "blk.%d.time_mix_w1" }, + { LLM_TENSOR_TIME_MIX_W2, "blk.%d.time_mix_w2" }, + { LLM_TENSOR_TIME_MIX_LERP_X, "blk.%d.time_mix_lerp_x" }, + { LLM_TENSOR_TIME_MIX_LERP_W, "blk.%d.time_mix_lerp_w" }, + { LLM_TENSOR_TIME_MIX_LERP_K, "blk.%d.time_mix_lerp_k" }, + { LLM_TENSOR_TIME_MIX_LERP_V, "blk.%d.time_mix_lerp_v" }, + { LLM_TENSOR_TIME_MIX_LERP_R, "blk.%d.time_mix_lerp_r" }, + { LLM_TENSOR_TIME_MIX_LERP_G, "blk.%d.time_mix_lerp_g" }, + { LLM_TENSOR_TIME_MIX_FIRST, "blk.%d.time_mix_first" }, + { LLM_TENSOR_TIME_MIX_DECAY, "blk.%d.time_mix_decay" }, + { LLM_TENSOR_TIME_MIX_DECAY_W1, "blk.%d.time_mix_decay_w1" }, + { LLM_TENSOR_TIME_MIX_DECAY_W2, "blk.%d.time_mix_decay_w2" }, + { LLM_TENSOR_TIME_MIX_KEY, "blk.%d.time_mix_key" }, + { LLM_TENSOR_TIME_MIX_VALUE, "blk.%d.time_mix_value" }, + { LLM_TENSOR_TIME_MIX_RECEPTANCE, "blk.%d.time_mix_receptance" }, + { LLM_TENSOR_TIME_MIX_GATE, "blk.%d.time_mix_gate" }, + { LLM_TENSOR_TIME_MIX_LN, "blk.%d.time_mix_ln" }, + { LLM_TENSOR_TIME_MIX_OUTPUT, "blk.%d.time_mix_output" }, + { LLM_TENSOR_CHANNEL_MIX_LERP_K, "blk.%d.channel_mix_lerp_k" }, + { LLM_TENSOR_CHANNEL_MIX_LERP_R, "blk.%d.channel_mix_lerp_r" }, + { LLM_TENSOR_CHANNEL_MIX_KEY, "blk.%d.channel_mix_key" }, + { LLM_TENSOR_CHANNEL_MIX_VALUE, "blk.%d.channel_mix_value" }, + { LLM_TENSOR_CHANNEL_MIX_RECEPTANCE, "blk.%d.channel_mix_receptance" }, + }, + }, { LLM_ARCH_UNKNOWN, { @@ -2151,6 +2220,7 @@ enum e_model { MODEL_1B, MODEL_1_3B, MODEL_1_4B, + MODEL_1_6B, MODEL_2B, MODEL_2_8B, MODEL_3B, @@ -2228,6 +2298,12 @@ struct llama_hparams { float f_attn_logit_softcapping = 50.0f; float f_final_logit_softcapping = 30.0f; + // for RWKV + uint32_t rescale_every_n_layers = 0; + uint32_t time_mix_extra_dim = 0; + uint32_t time_decay_extra_dim = 0; + uint32_t wkv_head_size = 0; + float rope_attn_factor = 1.0f; float rope_freq_base_train; float rope_freq_scale_train; @@ -2291,6 +2367,11 @@ struct llama_hparams { if (this->ssm_dt_rank != other.ssm_dt_rank) return true; if (this->ssm_dt_b_c_rms != other.ssm_dt_b_c_rms) return true; + if (this->rescale_every_n_layers != other.rescale_every_n_layers) return true; + if (this->time_mix_extra_dim != other.time_mix_extra_dim) return true; + if (this->time_decay_extra_dim != other.time_decay_extra_dim) return true; + if (this->wkv_head_size != other.wkv_head_size) return true; + if (this->dec_start_token_id != other.dec_start_token_id) return true; const float EPSILON = 1e-9f; @@ -2354,15 +2435,25 @@ struct llama_hparams { } uint32_t n_embd_k_s() const { // dimension of the rolling state embeddings - // corresponds to Mamba's conv_states size - // TODO: maybe support other convolution strides than 1 - // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed - return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner; + // corresponds to Mamba's conv_states size or RWKV's token_shift states size + if (wkv_head_size != 0) { + // for RWKV models + return 2 * n_embd; + } else { + // TODO: maybe support other convolution strides than 1 + // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed + return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner; + } } uint32_t n_embd_v_s() const { // dimension of the recurrent state embeddings - // corresponds to Mamba's ssm_states size - return ssm_d_state * ssm_d_inner; + if (wkv_head_size != 0) { + // corresponds to RWKV's wkv_states size + return n_embd * wkv_head_size; + } else { + // corresponds to Mamba's ssm_states size + return ssm_d_state * ssm_d_inner; + } } }; @@ -2501,6 +2592,36 @@ struct llama_layer { struct ggml_tensor * ssm_conv1d_b; struct ggml_tensor * ssm_dt_b; + // rwkv + struct ggml_tensor * time_mix_w1; + struct ggml_tensor * time_mix_w2; + struct ggml_tensor * time_mix_lerp_x; + struct ggml_tensor * time_mix_lerp_w; + struct ggml_tensor * time_mix_lerp_k; + struct ggml_tensor * time_mix_lerp_v; + struct ggml_tensor * time_mix_lerp_r; + struct ggml_tensor * time_mix_lerp_g; + + struct ggml_tensor * time_mix_first; + struct ggml_tensor * time_mix_decay; + struct ggml_tensor * time_mix_decay_w1; + struct ggml_tensor * time_mix_decay_w2; + struct ggml_tensor * time_mix_key; + struct ggml_tensor * time_mix_value; + struct ggml_tensor * time_mix_receptance; + struct ggml_tensor * time_mix_gate; + + struct ggml_tensor * time_mix_ln; + struct ggml_tensor * time_mix_ln_b; + struct ggml_tensor * time_mix_output; + + struct ggml_tensor * channel_mix_lerp_k; + struct ggml_tensor * channel_mix_lerp_r; + + struct ggml_tensor * channel_mix_key; + struct ggml_tensor * channel_mix_receptance; + struct ggml_tensor * channel_mix_value; + // long rope factors struct ggml_tensor * rope_long = nullptr; struct ggml_tensor * rope_short = nullptr; @@ -3426,7 +3547,7 @@ static bool llama_kv_cache_find_slot( const uint32_t n_seq_tokens = batch.n_seq_tokens; if (cache.recurrent) { - // For recurrent state architectures (like Mamba), + // For recurrent state architectures (like Mamba or RWKV), // each cache cell can store the state for a whole sequence. // A slot should be always be contiguous. @@ -3675,7 +3796,7 @@ static bool llama_kv_cache_seq_rm( if (p0 < 0) p0 = 0; if (p1 < 0) p1 = std::numeric_limits::max(); - // models like Mamba can't have a state partially erased + // models like Mamba or RWKV can't have a state partially erased if (cache.recurrent) { if (seq_id >= (int64_t) cache.size) { // could be fatal @@ -3811,7 +3932,7 @@ static void llama_kv_cache_seq_add( if (p0 == p1) return; if (cache.recurrent) { - // for Mamba-like models, only the pos needs to be shifted + // for Mamba-like or RWKV models, only the pos needs to be shifted if (0 <= seq_id && seq_id < (int64_t) cache.size) { const int32_t tail_id = cache.cells[seq_id].tail; if (tail_id >= 0) { @@ -3860,7 +3981,7 @@ static void llama_kv_cache_seq_div( if (p0 == p1) return; if (cache.recurrent) { - // for Mamba-like models, only the pos needs to be changed + // for Mamba-like or RWKV models, only the pos needs to be changed if (0 <= seq_id && seq_id < (int64_t) cache.size) { const int32_t tail_id = cache.cells[seq_id].tail; if (tail_id >= 0) { @@ -5051,6 +5172,7 @@ static const char * llama_model_type_name(e_model type) { case MODEL_1B: return "1B"; case MODEL_1_3B: return "1.3B"; case MODEL_1_4B: return "1.4B"; + case MODEL_1_6B: return "1.6B"; case MODEL_2B: return "2B"; case MODEL_2_8B: return "2.8B"; case MODEL_3B: return "3B"; @@ -5097,6 +5219,7 @@ static const char * llama_model_vocab_type_name(enum llama_vocab_type type){ case LLAMA_VOCAB_TYPE_BPE: return "BPE"; case LLAMA_VOCAB_TYPE_WPM: return "WPM"; case LLAMA_VOCAB_TYPE_UGM: return "UGM"; + case LLAMA_VOCAB_TYPE_RWKV: return "RWKV"; default: return "unknown"; } } @@ -5793,6 +5916,26 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_RWKV6: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_WKV_HEAD_SIZE, hparams.wkv_head_size); + ml.get_key(LLM_KV_TIME_MIX_EXTRA_DIM, hparams.time_mix_extra_dim); + ml.get_key(LLM_KV_TIME_DECAY_EXTRA_DIM, hparams.time_decay_extra_dim); + ml.get_key(LLM_KV_RESCALE_EVERY_N_LAYERS, hparams.rescale_every_n_layers, false); + + switch (hparams.n_layer) { + case 24: model.type = e_model::MODEL_1_6B; break; + case 32: + switch (hparams.n_embd) { + case 2560: model.type = e_model::MODEL_3B; break; + case 4096: model.type = e_model::MODEL_7B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } break; + case 61: model.type = e_model::MODEL_14B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; default: (void)0; } @@ -5922,6 +6065,15 @@ static void llm_load_vocab( } #endif } + } else if (tokenizer_model == "rwkv") { + vocab.type = LLAMA_VOCAB_TYPE_RWKV; + + // default special tokens + vocab.special_bos_id = -1; + vocab.special_eos_id = -1; + vocab.special_unk_id = -1; + vocab.special_sep_id = -1; + vocab.special_pad_id = -1; } else { throw std::runtime_error(format("unknown tokenizer: '%s'", tokenizer_model.c_str())); } @@ -6053,6 +6205,12 @@ static void llm_load_vocab( vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; vocab.tokenizer_add_bos = false; vocab.tokenizer_add_eos = true; + } else if (vocab.type == LLAMA_VOCAB_TYPE_RWKV) { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; + vocab.tokenizer_add_space_prefix = false; + vocab.tokenizer_clean_spaces = false; + vocab.tokenizer_add_bos = false; + vocab.tokenizer_add_eos = false; } else { vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; } @@ -6157,6 +6315,10 @@ static void llm_load_vocab( } } else if (vocab.type == LLAMA_VOCAB_TYPE_WPM) { vocab.linefeed_id = vocab.special_pad_id; + } else if (vocab.type == LLAMA_VOCAB_TYPE_RWKV) { + const std::vector ids = llama_tokenize_internal(vocab, "\n", false); + GGML_ASSERT(!ids.empty() && "model vocab missing newline token"); + vocab.linefeed_id = ids[0]; } else { const std::vector ids = llama_tokenize_internal(vocab, "\xC4\x8A", false); // U+010A GGML_ASSERT(!ids.empty() && "model vocab missing newline token"); @@ -8203,6 +8365,68 @@ static bool llm_load_tensors( layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); } } break; + case LLM_ARCH_RWKV6: + { + model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + + // Block 0, LN0 + model.tok_norm = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}); + model.tok_norm_b = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}); + + // output + model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); + model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); + + const int time_mix_extra_dim = hparams.time_mix_extra_dim; + const int time_decay_extra_dim = hparams.time_decay_extra_dim; + const int head_size = hparams.wkv_head_size; + const int attn_hidden_size = n_embd; + const int ffn_size = hparams.n_ff_arr[0]; + + for (int i = 0; i < n_layer; ++i) { + ggml_context * ctx_layer = ctx_for_layer(i); + + auto & layer = model.layers[i]; + + layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}); + + layer.attn_norm_2 = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}); + layer.attn_norm_2_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}); + + layer.time_mix_w1 = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_W1, "weight", i), {n_embd, time_mix_extra_dim * 5}); + layer.time_mix_w2 = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_W2, "weight", i), {time_mix_extra_dim, n_embd, 5}); + + layer.time_mix_lerp_x = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_LERP_X, "weight", i), {n_embd, 1, 1}); + layer.time_mix_lerp_w = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_LERP_W, "weight", i), {n_embd, 1, 1}); + layer.time_mix_lerp_k = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_LERP_K, "weight", i), {n_embd, 1, 1}); + layer.time_mix_lerp_v = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_LERP_V, "weight", i), {n_embd, 1, 1}); + layer.time_mix_lerp_r = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_LERP_R, "weight", i), {n_embd, 1, 1}); + layer.time_mix_lerp_g = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_LERP_G, "weight", i), {n_embd, 1, 1}); + + layer.time_mix_first = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_FIRST, "weight", i), {head_size, n_embd / head_size}); + layer.time_mix_decay = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_DECAY, "weight", i), {n_embd}); + layer.time_mix_decay_w1 = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_DECAY_W1, "weight", i), {n_embd, time_decay_extra_dim}); + layer.time_mix_decay_w2 = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_DECAY_W2, "weight", i), {time_decay_extra_dim, attn_hidden_size}); + layer.time_mix_key = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_KEY, "weight", i), {attn_hidden_size, n_embd}); + layer.time_mix_value = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_VALUE, "weight", i), {attn_hidden_size, n_embd}); + layer.time_mix_receptance = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "weight", i), {attn_hidden_size, n_embd}); + layer.time_mix_gate = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_GATE, "weight", i), {attn_hidden_size, n_embd}); + + layer.time_mix_ln = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_LN, "weight", i), {n_embd}); + layer.time_mix_ln_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_LN, "bias", i), {n_embd}); + layer.time_mix_output = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_OUTPUT, "weight", i), {n_embd, attn_hidden_size}); + + layer.channel_mix_lerp_k = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_CHANNEL_MIX_LERP_K, "weight", i), {n_embd, 1, 1}); + layer.channel_mix_lerp_r = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_CHANNEL_MIX_LERP_R, "weight", i), {n_embd, 1, 1}); + + layer.channel_mix_key = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_CHANNEL_MIX_KEY, "weight", i), {n_embd, ffn_size}); + layer.channel_mix_value = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_CHANNEL_MIX_VALUE, "weight", i), {ffn_size, n_embd}); + layer.channel_mix_receptance = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_CHANNEL_MIX_RECEPTANCE, "weight", i), {n_embd, n_embd}); + } + + } break; default: throw std::runtime_error("unknown architecture"); } @@ -9162,6 +9386,171 @@ static struct ggml_tensor * llm_build_mamba( return cur; } +static struct ggml_tensor * llm_build_rwkv6_time_mix( + struct llama_context & lctx, + struct ggml_context * ctx, + const struct llama_layer * layer, + struct ggml_tensor * cur, + struct ggml_tensor * x_prev, + struct ggml_tensor ** wkv_state) { + size_t n_embed = cur->ne[0]; + size_t n_seq_tokens = cur->ne[1]; + size_t n_seqs = cur->ne[2]; + + size_t head_size = layer->time_mix_first->ne[0]; + size_t head_count = layer->time_mix_first->ne[1]; + + size_t n_tokens = n_seqs * n_seq_tokens; + + struct ggml_tensor * sx = ggml_sub(ctx, x_prev, cur); + + sx = ggml_reshape_2d(ctx, sx, n_embed, n_tokens); + cur = ggml_reshape_2d(ctx, cur, n_embed, n_tokens); + + struct ggml_tensor * xxx = ggml_add(ctx, ggml_mul(ctx, sx, layer->time_mix_lerp_x), cur); + + xxx = ggml_reshape_4d( + ctx, + ggml_tanh( + ctx, + ggml_mul_mat(ctx, layer->time_mix_w1, xxx) + ), + layer->time_mix_w1->ne[1] / 5, 1, 5, n_tokens + ); + + xxx = ggml_cont(ctx, ggml_permute(ctx, xxx, 0, 1, 3, 2)); + + xxx = ggml_mul_mat( + ctx, + ggml_reshape_4d( + ctx, + layer->time_mix_w2, + layer->time_mix_w2->ne[0], layer->time_mix_w2->ne[1], 1, 5 + ), + xxx + ); + + struct ggml_tensor *mw = ggml_view_2d(ctx, xxx, n_embed, n_tokens, xxx->nb[1], 0); + struct ggml_tensor *mk = ggml_view_2d(ctx, xxx, n_embed, n_tokens, xxx->nb[1], n_embed * n_tokens * sizeof(float)); + struct ggml_tensor *mv = ggml_view_2d(ctx, xxx, n_embed, n_tokens, xxx->nb[1], n_embed * n_tokens * 2 * sizeof(float)); + struct ggml_tensor *mr = ggml_view_2d(ctx, xxx, n_embed, n_tokens, xxx->nb[1], n_embed * n_tokens * 3 * sizeof(float)); + struct ggml_tensor *mg = ggml_view_2d(ctx, xxx, n_embed, n_tokens, xxx->nb[1], n_embed * n_tokens * 4 * sizeof(float)); + + struct ggml_tensor * xw = ggml_add( + ctx, + ggml_mul( + ctx, + ggml_add(ctx, mw, layer->time_mix_lerp_w), + sx + ), + cur + ); + + struct ggml_tensor * xk = ggml_add( + ctx, + ggml_mul( + ctx, + ggml_add(ctx, mk, layer->time_mix_lerp_k), + sx + ), + cur + ); + + struct ggml_tensor * xv = ggml_add( + ctx, + ggml_mul( + ctx, + ggml_add(ctx, mv, layer->time_mix_lerp_v), + sx + ), + cur + ); + + struct ggml_tensor * xr = ggml_add( + ctx, + ggml_mul( + ctx, + ggml_add(ctx, mr, layer->time_mix_lerp_r), + sx + ), + cur + ); + + struct ggml_tensor * xg = ggml_add( + ctx, + ggml_mul( + ctx, + ggml_add(ctx, mg, layer->time_mix_lerp_g), + sx + ), + cur + ); + + struct ggml_tensor * r = ggml_reshape_4d(ctx, llm_build_lora_mm(lctx, ctx, layer->time_mix_receptance, xr), head_size, 1, head_count, n_tokens); + struct ggml_tensor * k = ggml_reshape_4d(ctx, llm_build_lora_mm(lctx, ctx, layer->time_mix_key, xk), 1, head_size, head_count, n_tokens); + struct ggml_tensor * v = ggml_reshape_4d(ctx, llm_build_lora_mm(lctx, ctx, layer->time_mix_value, xv), head_size, 1, head_count, n_tokens); + struct ggml_tensor * g = ggml_silu( + ctx, + llm_build_lora_mm(lctx, ctx, layer->time_mix_gate, xg) + ); + + struct ggml_tensor * w = ggml_mul_mat( + ctx, + layer->time_mix_decay_w2, + ggml_tanh( + ctx, + ggml_mul_mat(ctx, layer->time_mix_decay_w1, xw) + ) + ); + + w = ggml_add(ctx, w, ggml_reshape_1d(ctx, layer->time_mix_decay, n_embed)); + w = ggml_exp(ctx, ggml_neg(ctx, ggml_exp(ctx, w))); + w = ggml_reshape_4d(ctx, w, 1, head_size, head_count, n_tokens); + + k = ggml_transpose(ctx, k); + v = ggml_transpose(ctx, v); + r = ggml_transpose(ctx, r); + + struct ggml_tensor * wkv_output = ggml_rwkv_wkv(ctx, k, v, r, layer->time_mix_first, w, *wkv_state); + cur = ggml_view_1d(ctx, wkv_output, n_embed * n_tokens, 0); + *wkv_state = ggml_view_1d(ctx, wkv_output, n_embed * head_size * n_seqs, n_embed * n_tokens * sizeof(float)); + + // group norm with head_count groups + cur = ggml_reshape_3d(ctx, cur, n_embed / head_count, head_count, n_tokens); + cur = ggml_norm(ctx, cur, 64e-5f); + + // Convert back to regular vectors. + cur = ggml_reshape_2d(ctx, cur, n_embed, n_tokens); + cur = ggml_add(ctx, ggml_mul(ctx, cur, layer->time_mix_ln), layer->time_mix_ln_b); + + cur = ggml_mul(ctx, cur, g); + cur = llm_build_lora_mm(lctx, ctx, layer->time_mix_output, cur); + + return ggml_reshape_3d(ctx, cur, n_embed, n_seq_tokens, n_seqs); +} + +static struct ggml_tensor * llm_build_rwkv6_channel_mix( + struct llama_context & lctx, + struct ggml_context * ctx, + const struct llama_layer * layer, + struct ggml_tensor * cur, + struct ggml_tensor * x_prev) { + struct ggml_tensor * sx = ggml_sub(ctx, x_prev, cur); + struct ggml_tensor * xk = ggml_add(ctx, ggml_mul(ctx, sx, layer->channel_mix_lerp_k), cur); + struct ggml_tensor * xr = ggml_add(ctx, ggml_mul(ctx, sx, layer->channel_mix_lerp_r), cur); + + struct ggml_tensor * r = ggml_sigmoid(ctx, llm_build_lora_mm(lctx, ctx, layer->channel_mix_receptance, xr)); + struct ggml_tensor * k = ggml_sqr( + ctx, + ggml_relu( + ctx, + llm_build_lora_mm(lctx, ctx, layer->channel_mix_key, xk) + ) + ); + + return ggml_mul(ctx, r, llm_build_lora_mm(lctx, ctx, layer->channel_mix_value, k)); +} + struct llm_build_context { const llama_model & model; llama_context & lctx; @@ -14683,6 +15072,117 @@ struct llm_build_context { return gf; } + + ggml_cgraph * build_rwkv6() { + ggml_cgraph *gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + + // Token shift state dimensions should be 2 * n_emb + GGML_ASSERT(n_embd == hparams.n_embd_k_s() / 2); + + const int64_t n_seqs = batch.n_seqs; + const int64_t n_seq_tokens = batch.n_seq_tokens; + const int64_t n_tokens = batch.n_tokens; + GGML_ASSERT(n_seqs != 0); + GGML_ASSERT(batch.equal_seqs); + GGML_ASSERT(n_tokens == n_seq_tokens * n_seqs); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + struct ggml_tensor * state_copy = build_inp_s_copy(); + struct ggml_tensor * state_mask = build_inp_s_mask(); + + inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + inpL = llm_build_norm(ctx0, inpL, hparams, model.tok_norm, model.tok_norm_b, LLM_NORM, cb, -1); + + for (int il = 0; il < n_layer; ++il) { + const llama_layer * layer = &model.layers[il]; + + // (ab)using the KV cache to store the states + struct ggml_tensor * token_shift = llm_build_copy_mask_state(ctx0, + gf, kv_self.k_l[il], state_copy, state_mask, + hparams.n_embd_k_s(), kv_self.size, kv_head, n_kv, n_seqs); + struct ggml_tensor * wkv_states = llm_build_copy_mask_state(ctx0, + gf, kv_self.v_l[il], state_copy, state_mask, + hparams.n_embd_v_s(), kv_self.size, kv_head, n_kv, n_seqs); + + cur = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs); + token_shift = ggml_reshape_3d(ctx0, token_shift, n_embd, 2, n_seqs); + + struct ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0); + struct ggml_tensor * ffn_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], n_embd * ggml_element_size(token_shift)); + + struct ggml_tensor * x_norm_att = llm_build_norm(ctx0, cur, hparams, layer->attn_norm, layer->attn_norm_b, LLM_NORM, cb, il); + struct ggml_tensor * x_prev = ggml_concat( + ctx0, + att_shift, + ggml_view_3d(ctx0, x_norm_att, n_embd, n_seq_tokens - 1, n_seqs, x_norm_att->nb[1], x_norm_att->nb[2], 0), + 1 + ); + + cur = ggml_add(ctx0, cur, llm_build_rwkv6_time_mix(lctx, ctx0, layer, x_norm_att, x_prev, &wkv_states)); + ggml_build_forward_expand(gf, cur); + ggml_build_forward_expand( + gf, + ggml_cpy( + ctx0, + wkv_states, + ggml_view_1d( + ctx0, + kv_self.v_l[il], + hparams.n_embd_v_s() * n_seqs, + hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_self.v_l[il]) + ) + ) + ); + + struct ggml_tensor * x_norm_ffn = llm_build_norm(ctx0, cur, hparams, layer->attn_norm_2, layer->attn_norm_2_b, LLM_NORM, cb, il); + x_prev = ggml_concat( + ctx0, + ffn_shift, + ggml_view_3d(ctx0, x_norm_ffn, n_embd, n_seq_tokens - 1, n_seqs, x_norm_ffn->nb[1], x_norm_ffn->nb[2], 0), + 1 + ); + cur = ggml_add(ctx0, cur, llm_build_rwkv6_channel_mix(lctx, ctx0, layer, x_norm_ffn, x_prev)); + ggml_build_forward_expand(gf, cur); + + struct ggml_tensor * last_norm_att = ggml_view_3d(ctx0, x_norm_att, n_embd, 1, n_seqs, x_norm_att->nb[1], x_norm_att->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(x_norm_att)); + struct ggml_tensor * last_norm_ffn = ggml_view_3d(ctx0, x_norm_ffn, n_embd, 1, n_seqs, x_norm_ffn->nb[1], x_norm_ffn->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(x_norm_ffn)); + + token_shift = ggml_concat(ctx0, last_norm_att, last_norm_ffn, 1); + + ggml_build_forward_expand( + gf, + ggml_cpy( + ctx0, + ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * 2, 0), + ggml_view_1d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s() * n_seqs, hparams.n_embd_k_s() * kv_head * ggml_element_size(kv_self.k_l[il])) + ) + ); + + if (hparams.rescale_every_n_layers != 0 && (il + 1) % hparams.rescale_every_n_layers == 0) { + cur = ggml_scale(ctx0, cur, 0.5F); + } + + cur = lctx.cvec.apply_to(ctx0, cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + + cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, model.output_norm_b, LLM_NORM, cb, -1); + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); + + cb(cur, "result_output", -1); + ggml_build_forward_expand(gf, cur); + + return gf; + } }; static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector & ids) { @@ -14929,6 +15429,10 @@ static struct ggml_cgraph * llama_build_graph( { result = llm.build_exaone(); } break; + case LLM_ARCH_RWKV6: + { + result = llm.build_rwkv6(); + } break; default: GGML_ABORT("fatal error"); } @@ -16973,6 +17477,11 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s // NOTE: can't use LLM_TN here because the layer number is not known quantize &= name.find("ssm_conv1d.weight") == std::string::npos; + // do not quantize RWKV's time_mix_first tensors + quantize &= name.find("time_mix_first.weight") == std::string::npos; + quantize &= name.find("time_mix_w1.weight") == std::string::npos; + quantize &= name.find("time_mix_w2.weight") == std::string::npos; + // do not quantize relative position bias (T5) quantize &= name.find("attn_rel_b.weight") == std::string::npos; @@ -17977,6 +18486,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { case LLM_ARCH_T5: case LLM_ARCH_T5ENCODER: case LLM_ARCH_JAIS: + case LLM_ARCH_RWKV6: return LLAMA_ROPE_TYPE_NONE; // use what we call a normal RoPE, operating on pairs of consecutive head values @@ -18145,6 +18655,7 @@ llama_token llama_model_decoder_start_token(const struct llama_model * model) { bool llama_model_is_recurrent(const struct llama_model * model) { switch (model->arch) { case LLM_ARCH_MAMBA: return true; + case LLM_ARCH_RWKV6: return true; default: return false; } } From c6d4cb46559b359d2682cf2a002e7fe01bb7a767 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 2 Sep 2024 11:52:04 +0300 Subject: [PATCH 06/22] llama : minor style --- src/llama.cpp | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index 2113c72f3..4e203471c 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -8711,8 +8711,7 @@ static void llm_build_kv_store( GGML_ASSERT(kv.size == n_ctx); - struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, kv.k_l[il], n_tokens*n_embd_k_gqa, - (ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa))*kv_head); + struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, kv.k_l[il], n_tokens*n_embd_k_gqa, ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa)*kv_head); cb(k_cache_view, "k_cache_view", il); // note: storing RoPE-ed version of K in the KV cache @@ -8723,8 +8722,7 @@ static void llm_build_kv_store( struct ggml_tensor * v_cache_view = nullptr; if (cparams.flash_attn) { - v_cache_view = ggml_view_1d(ctx, kv.v_l[il], n_tokens*n_embd_v_gqa, - (kv_head)*ggml_row_size(kv.v_l[il]->type, n_embd_v_gqa)); + v_cache_view = ggml_view_1d(ctx, kv.v_l[il], n_tokens*n_embd_v_gqa, ggml_row_size(kv.v_l[il]->type, n_embd_v_gqa)*kv_head); } else { // note: the V cache is transposed when not using flash attention v_cache_view = ggml_view_2d(ctx, kv.v_l[il], n_tokens, n_embd_v_gqa, @@ -9211,8 +9209,7 @@ static struct ggml_tensor * llm_build_kv( struct ggml_tensor * cur; - cur = llm_build_kqv(ctx, lctx, kv, graph, wo, wo_b, - q_cur, kq_mask, n_tokens, n_kv, kq_scale, cb, il); + cur = llm_build_kqv(ctx, lctx, kv, graph, wo, wo_b, q_cur, kq_mask, n_tokens, n_kv, kq_scale, cb, il); cb(cur, "kqv_out", il); return cur; From 9c1ba557335c3cab46807e6478eb7d17a63361f1 Mon Sep 17 00:00:00 2001 From: Tushar Date: Mon, 2 Sep 2024 16:51:01 +0530 Subject: [PATCH 07/22] build(nix): Package gguf-py (#5664) * style: format with nixfmt/rfc101-style * build(nix): Package gguf-py * build(nix): Refactor to new scope for gguf-py * build(nix): Exclude gguf-py from devShells * build(nix): Refactor gguf-py derivation to take in exact deps * build(nix): Enable pytestCheckHook and pythonImportsCheck for gguf-py * build(python): Package python scripts with pyproject.toml * chore: Cleanup * dev(nix): Break up python/C devShells * build(python): Relax pytorch version constraint Nix has an older version * chore: Move cmake to nativeBuildInputs for devShell * fmt: Reconcile formatting with rebase * style: nix fmt * cleanup: Remove unncessary __init__.py * chore: Suggestions from review - Filter out non-source files from llama-scripts flake derivation - Clean up unused closure - Remove scripts devShell * revert: Bad changes * dev: Simplify devShells, restore the -extra devShell * build(nix): Add pyyaml for gguf-py * chore: Remove some unused bindings * dev: Add tiktoken to -extra devShells --- .devops/nix/devshells.nix | 53 ++++- .devops/nix/nixpkgs-instances.nix | 18 +- .devops/nix/package-gguf-py.nix | 36 +++ .devops/nix/package.nix | 364 ++++++++++++------------------ .devops/nix/python-scripts.nix | 66 ++++++ .devops/nix/scope.nix | 40 +++- flake.nix | 5 +- gguf-py/pyproject.toml | 1 + pyproject.toml | 2 +- 9 files changed, 337 insertions(+), 248 deletions(-) create mode 100644 .devops/nix/package-gguf-py.nix create mode 100644 .devops/nix/python-scripts.nix diff --git a/.devops/nix/devshells.nix b/.devops/nix/devshells.nix index 1862f0f08..bfd304af1 100644 --- a/.devops/nix/devshells.nix +++ b/.devops/nix/devshells.nix @@ -1,13 +1,52 @@ +{ inputs, ... }: + { perSystem = - { config, lib, ... }: + { + config, + lib, + system, + ... + }: { devShells = - lib.concatMapAttrs - (name: package: { - ${name} = package.passthru.shell; - ${name + "-extra"} = package.passthru.shell-extra; - }) - config.packages; + let + pkgs = import inputs.nixpkgs { inherit system; }; + stdenv = pkgs.stdenv; + scripts = config.packages.python-scripts; + in + lib.pipe (config.packages) [ + (lib.concatMapAttrs ( + name: package: { + ${name} = pkgs.mkShell { + name = "${name}"; + inputsFrom = [ package ]; + shellHook = '' + echo "Entering ${name} devShell" + ''; + }; + "${name}-extra" = + if (name == "python-scripts") then + null + else + pkgs.mkShell { + name = "${name}-extra"; + inputsFrom = [ + package + scripts + ]; + # Extra packages that *may* be used by some scripts + packages = [ + pkgs.python3Packages.tiktoken + ]; + shellHook = '' + echo "Entering ${name} devShell" + addToSearchPath "LD_LIBRARY_PATH" "${lib.getLib stdenv.cc.cc}/lib" + ''; + }; + } + )) + (lib.filterAttrs (name: value: value != null)) + ]; }; } diff --git a/.devops/nix/nixpkgs-instances.nix b/.devops/nix/nixpkgs-instances.nix index 4a2f81c4b..90d683a71 100644 --- a/.devops/nix/nixpkgs-instances.nix +++ b/.devops/nix/nixpkgs-instances.nix @@ -26,16 +26,14 @@ config.cudaSupport = true; config.allowUnfreePredicate = p: - builtins.all - ( - license: - license.free - || builtins.elem license.shortName [ - "CUDA EULA" - "cuDNN EULA" - ] - ) - (p.meta.licenses or [ p.meta.license ]); + builtins.all ( + license: + license.free + || builtins.elem license.shortName [ + "CUDA EULA" + "cuDNN EULA" + ] + ) (p.meta.licenses or [ p.meta.license ]); }; # Ensure dependencies use ROCm consistently pkgsRocm = import inputs.nixpkgs { diff --git a/.devops/nix/package-gguf-py.nix b/.devops/nix/package-gguf-py.nix new file mode 100644 index 000000000..cca2f36a5 --- /dev/null +++ b/.devops/nix/package-gguf-py.nix @@ -0,0 +1,36 @@ +{ + lib, + llamaVersion, + numpy, + tqdm, + sentencepiece, + pyyaml, + poetry-core, + buildPythonPackage, + pytestCheckHook, +}: + +buildPythonPackage { + pname = "gguf"; + version = llamaVersion; + pyproject = true; + nativeBuildInputs = [ poetry-core ]; + propagatedBuildInputs = [ + numpy + tqdm + sentencepiece + pyyaml + ]; + src = lib.cleanSource ../../gguf-py; + pythonImportsCheck = [ + "numpy" + "gguf" + ]; + nativeCheckInputs = [ pytestCheckHook ]; + doCheck = true; + meta = with lib; { + description = "Python package for writing binary files in the GGUF format"; + license = licenses.mit; + maintainers = [ maintainers.ditsuke ]; + }; +} diff --git a/.devops/nix/package.nix b/.devops/nix/package.nix index cfffac257..5d7d7ea5a 100644 --- a/.devops/nix/package.nix +++ b/.devops/nix/package.nix @@ -3,13 +3,11 @@ glibc, config, stdenv, - mkShell, runCommand, cmake, ninja, pkg-config, git, - python3, mpi, blas, cudaPackages, @@ -20,15 +18,18 @@ vulkan-loader, curl, shaderc, - useBlas ? builtins.all (x: !x) [ - useCuda - useMetalKit - useRocm - useVulkan - ] && blas.meta.available, + useBlas ? + builtins.all (x: !x) [ + useCuda + useMetalKit + useRocm + useVulkan + ] + && blas.meta.available, useCuda ? config.cudaSupport, useMetalKit ? stdenv.isAarch64 && stdenv.isDarwin, - useMpi ? false, # Increases the runtime closure size by ~700M + # Increases the runtime closure size by ~700M + useMpi ? false, useRocm ? config.rocmSupport, enableCurl ? true, useVulkan ? false, @@ -38,8 +39,8 @@ # otherwise we get libstdc++ errors downstream. effectiveStdenv ? if useCuda then cudaPackages.backendStdenv else stdenv, enableStatic ? effectiveStdenv.hostPlatform.isStatic, - precompileMetalShaders ? false -}@inputs: + precompileMetalShaders ? false, +}: let inherit (lib) @@ -47,7 +48,6 @@ let cmakeFeature optionals strings - versionOlder ; stdenv = throw "Use effectiveStdenv instead"; @@ -63,54 +63,11 @@ let pnameSuffix = strings.optionalString (suffices != [ ]) "-${strings.concatMapStringsSep "-" strings.toLower suffices}"; - descriptionSuffix = - strings.optionalString (suffices != [ ]) - ", accelerated with ${strings.concatStringsSep ", " suffices}"; + descriptionSuffix = strings.optionalString ( + suffices != [ ] + ) ", accelerated with ${strings.concatStringsSep ", " suffices}"; - executableSuffix = effectiveStdenv.hostPlatform.extensions.executable; - - # TODO: package the Python in this repository in a Nix-like way. - # It'd be nice to migrate to buildPythonPackage, as well as ensure this repo - # is PEP 517-compatible, and ensure the correct .dist-info is generated. - # https://peps.python.org/pep-0517/ - # - # TODO: Package up each Python script or service appropriately, by making - # them into "entrypoints" - llama-python = python3.withPackages ( - ps: [ - ps.numpy - ps.sentencepiece - ] - ); - - # TODO(Green-Sky): find a better way to opt-into the heavy ml python runtime - llama-python-extra = python3.withPackages ( - ps: [ - ps.numpy - ps.sentencepiece - ps.tiktoken - ps.torchWithoutCuda - ps.transformers - - # server bench - ps.matplotlib - - # server tests - ps.openai - ps.behave - ps.prometheus-client - - # for examples/pydantic-models-to-grammar-examples.py - ps.docstring-parser - ps.pydantic - - # for scripts/compare-llama-bench.py - ps.gitpython - ps.tabulate - ] - ); - - xcrunHost = runCommand "xcrunHost" {} '' + xcrunHost = runCommand "xcrunHost" { } '' mkdir -p $out/bin ln -s /usr/bin/xcrun $out/bin ''; @@ -145,178 +102,145 @@ let ]; in -effectiveStdenv.mkDerivation ( - finalAttrs: { - pname = "llama-cpp${pnameSuffix}"; - version = llamaVersion; +effectiveStdenv.mkDerivation (finalAttrs: { + pname = "llama-cpp${pnameSuffix}"; + version = llamaVersion; - # Note: none of the files discarded here are visible in the sandbox or - # affect the output hash. This also means they can be modified without - # triggering a rebuild. - src = lib.cleanSourceWith { - filter = - name: type: - let - noneOf = builtins.all (x: !x); - baseName = baseNameOf name; - in - noneOf [ - (lib.hasSuffix ".nix" name) # Ignore *.nix files when computing outPaths - (lib.hasSuffix ".md" name) # Ignore *.md changes whe computing outPaths - (lib.hasPrefix "." baseName) # Skip hidden files and directories - (baseName == "flake.lock") - ]; - src = lib.cleanSource ../../.; - }; - - postPatch = '' - substituteInPlace ./ggml/src/ggml-metal.m \ - --replace '[bundle pathForResource:@"ggml-metal" ofType:@"metal"];' "@\"$out/bin/ggml-metal.metal\";" - substituteInPlace ./ggml/src/ggml-metal.m \ - --replace '[bundle pathForResource:@"default" ofType:@"metallib"];' "@\"$out/bin/default.metallib\";" - ''; - - # With PR#6015 https://github.com/ggerganov/llama.cpp/pull/6015, - # `default.metallib` may be compiled with Metal compiler from XCode - # and we need to escape sandbox on MacOS to access Metal compiler. - # `xcrun` is used find the path of the Metal compiler, which is varible - # and not on $PATH - # see https://github.com/ggerganov/llama.cpp/pull/6118 for discussion - __noChroot = effectiveStdenv.isDarwin && useMetalKit && precompileMetalShaders; - - nativeBuildInputs = - [ - cmake - ninja - pkg-config - git - ] - ++ optionals useCuda [ - cudaPackages.cuda_nvcc - autoAddDriverRunpath - ] - ++ optionals (effectiveStdenv.hostPlatform.isGnu && enableStatic) [ - glibc.static - ] ++ optionals (effectiveStdenv.isDarwin && useMetalKit && precompileMetalShaders) [ - xcrunHost + # Note: none of the files discarded here are visible in the sandbox or + # affect the output hash. This also means they can be modified without + # triggering a rebuild. + src = lib.cleanSourceWith { + filter = + name: type: + let + noneOf = builtins.all (x: !x); + baseName = baseNameOf name; + in + noneOf [ + (lib.hasSuffix ".nix" name) # Ignore *.nix files when computing outPaths + (lib.hasSuffix ".md" name) # Ignore *.md changes whe computing outPaths + (lib.hasPrefix "." baseName) # Skip hidden files and directories + (baseName == "flake.lock") ]; + src = lib.cleanSource ../../.; + }; - buildInputs = - optionals effectiveStdenv.isDarwin darwinBuildInputs - ++ optionals useCuda cudaBuildInputs - ++ optionals useMpi [ mpi ] - ++ optionals useRocm rocmBuildInputs - ++ optionals useBlas [ blas ] - ++ optionals useVulkan vulkanBuildInputs - ++ optionals enableCurl [ curl ]; + postPatch = '' + substituteInPlace ./ggml/src/ggml-metal.m \ + --replace '[bundle pathForResource:@"ggml-metal" ofType:@"metal"];' "@\"$out/bin/ggml-metal.metal\";" + substituteInPlace ./ggml/src/ggml-metal.m \ + --replace '[bundle pathForResource:@"default" ofType:@"metallib"];' "@\"$out/bin/default.metallib\";" + ''; - cmakeFlags = - [ - (cmakeBool "LLAMA_BUILD_SERVER" true) - (cmakeBool "BUILD_SHARED_LIBS" (!enableStatic)) - (cmakeBool "CMAKE_SKIP_BUILD_RPATH" true) - (cmakeBool "LLAMA_CURL" enableCurl) - (cmakeBool "GGML_NATIVE" false) - (cmakeBool "GGML_BLAS" useBlas) - (cmakeBool "GGML_CUDA" useCuda) - (cmakeBool "GGML_HIPBLAS" useRocm) - (cmakeBool "GGML_METAL" useMetalKit) - (cmakeBool "GGML_VULKAN" useVulkan) - (cmakeBool "GGML_STATIC" enableStatic) - ] - ++ optionals useCuda [ - ( - with cudaPackages.flags; - cmakeFeature "CMAKE_CUDA_ARCHITECTURES" ( - builtins.concatStringsSep ";" (map dropDot cudaCapabilities) - ) + # With PR#6015 https://github.com/ggerganov/llama.cpp/pull/6015, + # `default.metallib` may be compiled with Metal compiler from XCode + # and we need to escape sandbox on MacOS to access Metal compiler. + # `xcrun` is used find the path of the Metal compiler, which is varible + # and not on $PATH + # see https://github.com/ggerganov/llama.cpp/pull/6118 for discussion + __noChroot = effectiveStdenv.isDarwin && useMetalKit && precompileMetalShaders; + + nativeBuildInputs = + [ + cmake + ninja + pkg-config + git + ] + ++ optionals useCuda [ + cudaPackages.cuda_nvcc + + autoAddDriverRunpath + ] + ++ optionals (effectiveStdenv.hostPlatform.isGnu && enableStatic) [ glibc.static ] + ++ optionals (effectiveStdenv.isDarwin && useMetalKit && precompileMetalShaders) [ xcrunHost ]; + + buildInputs = + optionals effectiveStdenv.isDarwin darwinBuildInputs + ++ optionals useCuda cudaBuildInputs + ++ optionals useMpi [ mpi ] + ++ optionals useRocm rocmBuildInputs + ++ optionals useBlas [ blas ] + ++ optionals useVulkan vulkanBuildInputs + ++ optionals enableCurl [ curl ]; + + cmakeFlags = + [ + (cmakeBool "LLAMA_BUILD_SERVER" true) + (cmakeBool "BUILD_SHARED_LIBS" (!enableStatic)) + (cmakeBool "CMAKE_SKIP_BUILD_RPATH" true) + (cmakeBool "LLAMA_CURL" enableCurl) + (cmakeBool "GGML_NATIVE" false) + (cmakeBool "GGML_BLAS" useBlas) + (cmakeBool "GGML_CUDA" useCuda) + (cmakeBool "GGML_HIPBLAS" useRocm) + (cmakeBool "GGML_METAL" useMetalKit) + (cmakeBool "GGML_VULKAN" useVulkan) + (cmakeBool "GGML_STATIC" enableStatic) + ] + ++ optionals useCuda [ + ( + with cudaPackages.flags; + cmakeFeature "CMAKE_CUDA_ARCHITECTURES" ( + builtins.concatStringsSep ";" (map dropDot cudaCapabilities) ) - ] - ++ optionals useRocm [ - (cmakeFeature "CMAKE_HIP_COMPILER" "${rocmPackages.llvm.clang}/bin/clang") - (cmakeFeature "CMAKE_HIP_ARCHITECTURES" (builtins.concatStringsSep ";" rocmPackages.clr.gpuTargets)) - ] - ++ optionals useMetalKit [ - (lib.cmakeFeature "CMAKE_C_FLAGS" "-D__ARM_FEATURE_DOTPROD=1") - (cmakeBool "GGML_METAL_EMBED_LIBRARY" (!precompileMetalShaders)) - ]; + ) + ] + ++ optionals useRocm [ + (cmakeFeature "CMAKE_HIP_COMPILER" "${rocmPackages.llvm.clang}/bin/clang") + (cmakeFeature "CMAKE_HIP_ARCHITECTURES" (builtins.concatStringsSep ";" rocmPackages.clr.gpuTargets)) + ] + ++ optionals useMetalKit [ + (lib.cmakeFeature "CMAKE_C_FLAGS" "-D__ARM_FEATURE_DOTPROD=1") + (cmakeBool "GGML_METAL_EMBED_LIBRARY" (!precompileMetalShaders)) + ]; - # Environment variables needed for ROCm - env = optionals useRocm { - ROCM_PATH = "${rocmPackages.clr}"; - HIP_DEVICE_LIB_PATH = "${rocmPackages.rocm-device-libs}/amdgcn/bitcode"; - }; + # Environment variables needed for ROCm + env = optionals useRocm { + ROCM_PATH = "${rocmPackages.clr}"; + HIP_DEVICE_LIB_PATH = "${rocmPackages.rocm-device-libs}/amdgcn/bitcode"; + }; - # TODO(SomeoneSerge): It's better to add proper install targets at the CMake level, - # if they haven't been added yet. - postInstall = '' - mkdir -p $out/include - cp $src/include/llama.h $out/include/ - ''; + # TODO(SomeoneSerge): It's better to add proper install targets at the CMake level, + # if they haven't been added yet. + postInstall = '' + mkdir -p $out/include + cp $src/include/llama.h $out/include/ + ''; - # Define the shells here, but don't add in the inputsFrom to avoid recursion. - passthru = { - inherit - useBlas - useCuda - useMetalKit - useMpi - useRocm - useVulkan - ; + meta = { + # Configurations we don't want even the CI to evaluate. Results in the + # "unsupported platform" messages. This is mostly a no-op, because + # cudaPackages would've refused to evaluate anyway. + badPlatforms = optionals useCuda lib.platforms.darwin; - shell = mkShell { - name = "shell-${finalAttrs.finalPackage.name}"; - description = "contains numpy and sentencepiece"; - buildInputs = [ llama-python ]; - inputsFrom = [ finalAttrs.finalPackage ]; - shellHook = '' - addToSearchPath "LD_LIBRARY_PATH" "${lib.getLib effectiveStdenv.cc.cc}/lib" - ''; - }; + # Configurations that are known to result in build failures. Can be + # overridden by importing Nixpkgs with `allowBroken = true`. + broken = (useMetalKit && !effectiveStdenv.isDarwin); - shell-extra = mkShell { - name = "shell-extra-${finalAttrs.finalPackage.name}"; - description = "contains numpy, sentencepiece, torchWithoutCuda, and transformers"; - buildInputs = [ llama-python-extra ]; - inputsFrom = [ finalAttrs.finalPackage ]; - }; - }; + description = "Inference of LLaMA model in pure C/C++${descriptionSuffix}"; + homepage = "https://github.com/ggerganov/llama.cpp/"; + license = lib.licenses.mit; - meta = { - # Configurations we don't want even the CI to evaluate. Results in the - # "unsupported platform" messages. This is mostly a no-op, because - # cudaPackages would've refused to evaluate anyway. - badPlatforms = optionals useCuda lib.platforms.darwin; + # Accommodates `nix run` and `lib.getExe` + mainProgram = "llama-cli"; - # Configurations that are known to result in build failures. Can be - # overridden by importing Nixpkgs with `allowBroken = true`. - broken = (useMetalKit && !effectiveStdenv.isDarwin); + # These people might respond, on the best effort basis, if you ping them + # in case of Nix-specific regressions or for reviewing Nix-specific PRs. + # Consider adding yourself to this list if you want to ensure this flake + # stays maintained and you're willing to invest your time. Do not add + # other people without their consent. Consider removing people after + # they've been unreachable for long periods of time. - description = "Inference of LLaMA model in pure C/C++${descriptionSuffix}"; - homepage = "https://github.com/ggerganov/llama.cpp/"; - license = lib.licenses.mit; + # Note that lib.maintainers is defined in Nixpkgs, but you may just add + # an attrset following the same format as in + # https://github.com/NixOS/nixpkgs/blob/f36a80e54da29775c78d7eff0e628c2b4e34d1d7/maintainers/maintainer-list.nix + maintainers = with lib.maintainers; [ + philiptaron + SomeoneSerge + ]; - # Accommodates `nix run` and `lib.getExe` - mainProgram = "llama-cli"; - - # These people might respond, on the best effort basis, if you ping them - # in case of Nix-specific regressions or for reviewing Nix-specific PRs. - # Consider adding yourself to this list if you want to ensure this flake - # stays maintained and you're willing to invest your time. Do not add - # other people without their consent. Consider removing people after - # they've been unreachable for long periods of time. - - # Note that lib.maintainers is defined in Nixpkgs, but you may just add - # an attrset following the same format as in - # https://github.com/NixOS/nixpkgs/blob/f36a80e54da29775c78d7eff0e628c2b4e34d1d7/maintainers/maintainer-list.nix - maintainers = with lib.maintainers; [ - philiptaron - SomeoneSerge - ]; - - # Extend `badPlatforms` instead - platforms = lib.platforms.all; - }; - } -) + # Extend `badPlatforms` instead + platforms = lib.platforms.all; + }; +}) diff --git a/.devops/nix/python-scripts.nix b/.devops/nix/python-scripts.nix new file mode 100644 index 000000000..392e9ffe4 --- /dev/null +++ b/.devops/nix/python-scripts.nix @@ -0,0 +1,66 @@ +{ + lib, + stdenv, + buildPythonPackage, + poetry-core, + mkShell, + python3Packages, + gguf-py, +}@inputs: + +let + llama-python-deps = with python3Packages; [ + numpy + sentencepiece + transformers + protobuf + torchWithoutCuda + gguf-py + tqdm + + # for scripts/compare-llama-bench.py + gitpython + tabulate + + # for examples/pydantic-models-to-grammar-examples.py + docstring-parser + pydantic + + ]; + + llama-python-test-deps = with python3Packages; [ + # Server bench + matplotlib + + # server tests + openai + behave + prometheus-client + ]; +in + +buildPythonPackage ({ + pname = "llama-scripts"; + version = "0.0.0"; + pyproject = true; + + # NOTE: The files filtered out here are not visible in the build sandbox, neither + # do they affect the output hash. They can be modified without triggering a rebuild. + src = lib.cleanSourceWith { + filter = + name: type: + let + any = builtins.any (x: x); + baseName = builtins.baseNameOf name; + in + any [ + (lib.hasSuffix ".py" name) + (baseName == "README.md") + (baseName == "pyproject.toml") + ]; + src = lib.cleanSource ../../.; + }; + nativeBuildInputs = [ poetry-core ]; + nativeCheckInputs = llama-python-test-deps; + dependencies = llama-python-deps; +}) diff --git a/.devops/nix/scope.nix b/.devops/nix/scope.nix index 78530c9e8..478e8c422 100644 --- a/.devops/nix/scope.nix +++ b/.devops/nix/scope.nix @@ -1,19 +1,41 @@ { lib, newScope, + python3, llamaVersion ? "0.0.0", }: +let + pythonPackages = python3.pkgs; + buildPythonPackage = pythonPackages.buildPythonPackage; + numpy = pythonPackages.numpy; + tqdm = pythonPackages.tqdm; + sentencepiece = pythonPackages.sentencepiece; + pyyaml = pythonPackages.pyyaml; + poetry-core = pythonPackages.poetry-core; + pytestCheckHook = pythonPackages.pytestCheckHook; +in + # We're using `makeScope` instead of just writing out an attrset # because it allows users to apply overlays later using `overrideScope'`. # Cf. https://noogle.dev/f/lib/makeScope -lib.makeScope newScope ( - self: { - inherit llamaVersion; - llama-cpp = self.callPackage ./package.nix { }; - docker = self.callPackage ./docker.nix { }; - docker-min = self.callPackage ./docker.nix { interactive = false; }; - sif = self.callPackage ./sif.nix { }; - } -) +lib.makeScope newScope (self: { + inherit llamaVersion; + gguf-py = self.callPackage ./package-gguf-py.nix { + inherit + buildPythonPackage + numpy + tqdm + sentencepiece + poetry-core + pyyaml + pytestCheckHook + ; + }; + python-scripts = self.callPackage ./python-scripts.nix { inherit buildPythonPackage poetry-core; }; + llama-cpp = self.callPackage ./package.nix { }; + docker = self.callPackage ./docker.nix { }; + docker-min = self.callPackage ./docker.nix { interactive = false; }; + sif = self.callPackage ./sif.nix { }; +}) diff --git a/flake.nix b/flake.nix index c69637d11..26a258816 100644 --- a/flake.nix +++ b/flake.nix @@ -145,7 +145,9 @@ # the same path you would with an overlay. legacyPackages = { llamaPackages = pkgs.callPackage .devops/nix/scope.nix { inherit llamaVersion; }; - llamaPackagesWindows = pkgs.pkgsCross.mingwW64.callPackage .devops/nix/scope.nix { inherit llamaVersion; }; + llamaPackagesWindows = pkgs.pkgsCross.mingwW64.callPackage .devops/nix/scope.nix { + inherit llamaVersion; + }; llamaPackagesCuda = pkgsCuda.callPackage .devops/nix/scope.nix { inherit llamaVersion; }; llamaPackagesRocm = pkgsRocm.callPackage .devops/nix/scope.nix { inherit llamaVersion; }; }; @@ -157,6 +159,7 @@ default = config.legacyPackages.llamaPackages.llama-cpp; vulkan = config.packages.default.override { useVulkan = true; }; windows = config.legacyPackages.llamaPackagesWindows.llama-cpp; + python-scripts = config.legacyPackages.llamaPackages.python-scripts; } // lib.optionalAttrs pkgs.stdenv.isLinux { cuda = config.legacyPackages.llamaPackagesCuda.llama-cpp; diff --git a/gguf-py/pyproject.toml b/gguf-py/pyproject.toml index eea381e5a..33cfe26b7 100644 --- a/gguf-py/pyproject.toml +++ b/gguf-py/pyproject.toml @@ -23,6 +23,7 @@ python = ">=3.8" numpy = ">=1.17" tqdm = ">=4.27" pyyaml = ">=5.1" +sentencepiece = ">=0.1.98,<=0.2.0" [tool.poetry.dev-dependencies] pytest = "^5.2" diff --git a/pyproject.toml b/pyproject.toml index 25e2e20b2..84e71de6d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,7 +17,7 @@ classifiers = [ [tool.poetry.dependencies] python = ">=3.9" numpy = "^1.25.0" -sentencepiece = ">=0.1.98,<0.2.0" +sentencepiece = ">=0.1.98,<=0.2.0" transformers = ">=4.35.2,<5.0.0" protobuf = ">=4.21.0,<5.0.0" gguf = { path = "./gguf-py" } From b60074f1c26d8354053a952844ef3f7ebefd8803 Mon Sep 17 00:00:00 2001 From: Guoliang Hua <32868157+nbcsm@users.noreply.github.com> Date: Mon, 2 Sep 2024 20:36:43 +0800 Subject: [PATCH 08/22] llama-cli : remove duplicated log message (#9275) --- examples/main/main.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 2c05afb04..c55efbb66 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -386,8 +386,8 @@ int main(int argc, char ** argv) { } LOGLN( - "recalculate the cached logits (check): embd_inp.empty() %s, n_matching_session_tokens %zu, embd_inp.size() %zu, session_tokens.size() %zu, embd_inp.size() %zu", - log_tostr(embd_inp.empty()), n_matching_session_tokens, embd_inp.size(), session_tokens.size(), embd_inp.size()); + "recalculate the cached logits (check): embd_inp.empty() %s, n_matching_session_tokens %zu, embd_inp.size() %zu, session_tokens.size() %zu", + log_tostr(embd_inp.empty()), n_matching_session_tokens, embd_inp.size(), session_tokens.size()); // if we will use the cache for the full prompt without reaching the end of the cache, force // reevaluation of the last token to recalculate the cached logits From 6e7d133a5f9409dd257fad90d7f320721b07a1b2 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Mon, 2 Sep 2024 17:11:51 +0200 Subject: [PATCH 09/22] server : refactor multitask handling (#9274) * server : remove multitask from server_task * refactor completions handler * fix embeddings * use res_ok everywhere * small change for handle_slots_action * use unordered_set everywhere * (try) fix test * no more "mutable" lambda * Apply suggestions from code review Co-authored-by: Georgi Gerganov * use deque --------- Co-authored-by: Georgi Gerganov --- examples/server/server.cpp | 785 ++++++++---------- .../server/tests/features/parallel.feature | 4 +- examples/server/tests/features/steps/steps.py | 2 +- .../tests/features/wrong_usages.feature | 3 + examples/server/utils.hpp | 33 + 5 files changed, 365 insertions(+), 462 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index cc938e80d..109dbc023 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -5,13 +5,6 @@ #include "llama.h" #include "grammar-parser.h" -#ifndef NDEBUG -// crash the server in debug mode, otherwise send an http 500 error -#define CPPHTTPLIB_NO_EXCEPTIONS 1 -#endif -// increase max payload length to allow use of larger context size -#define CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 1048576 -#include "httplib.h" // Change JSON_ASSERT from assert() to GGML_ASSERT: #define JSON_ASSERT GGML_ASSERT #include "json.hpp" @@ -39,11 +32,13 @@ #include #include #include -#include #include #include #include #include +#include +#include +#include using json = nlohmann::ordered_json; @@ -82,21 +77,33 @@ enum server_task_type { SERVER_TASK_TYPE_SET_LORA, }; +enum server_task_cmpl_type { + SERVER_TASK_CMPL_TYPE_NORMAL, + SERVER_TASK_CMPL_TYPE_EMBEDDING, + SERVER_TASK_CMPL_TYPE_INFILL, +}; + struct server_task { int id = -1; // to be filled by server_queue - int id_multi = -1; - int id_target = -1; + int id_target = -1; // used by SERVER_TASK_TYPE_CANCEL server_task_type type; json data; - bool infill = false; - bool embedding = false; + server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL; + + // utility function + static std::unordered_set get_list_id(const std::vector & tasks) { + std::unordered_set ids(tasks.size()); + for (size_t i = 0; i < tasks.size(); i++) { + ids.insert(tasks[i].id); + } + return ids; + } }; struct server_task_result { int id = -1; - int id_multi = -1; json data; @@ -104,13 +111,6 @@ struct server_task_result { bool error; }; -struct server_task_multi { - int id = -1; - - std::set subtasks_remaining; - std::vector results; -}; - struct slot_params { bool stream = true; bool cache_prompt = false; // remember the prompt to avoid reprocessing all prompt @@ -128,7 +128,9 @@ struct slot_params { struct server_slot { int id; int id_task = -1; - int id_multi = -1; + + // the index relative to completion multi-task request + size_t index = 0; struct slot_params params; @@ -158,8 +160,7 @@ struct server_slot { std::vector cache_tokens; std::vector generated_token_probs; - bool infill = false; - bool embedding = false; + server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL; bool has_next_token = true; bool truncated = false; bool stopped_eos = false; @@ -204,7 +205,7 @@ struct server_slot { n_past = 0; n_sent_text = 0; n_sent_token_probs = 0; - infill = false; + cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL; ga_i = 0; n_past_se = 0; @@ -383,38 +384,56 @@ struct server_queue { bool running; // queues - std::vector queue_tasks; - std::vector queue_tasks_deferred; - - std::vector queue_multitasks; + std::deque queue_tasks; + std::deque queue_tasks_deferred; std::mutex mutex_tasks; std::condition_variable condition_tasks; // callback functions - std::function callback_new_task; - std::function callback_finish_multitask; - std::function callback_update_slots; + std::function callback_new_task; + std::function callback_update_slots; // Add a new task to the end of the queue - int post(server_task task) { + int post(server_task task, bool front = false) { std::unique_lock lock(mutex_tasks); if (task.id == -1) { task.id = id++; LOG_VERBOSE("new task id", {{"new_id", task.id}}); } - queue_tasks.push_back(std::move(task)); + if (front) { + queue_tasks.push_front(std::move(task)); + } else { + queue_tasks.push_back(std::move(task)); + } condition_tasks.notify_one(); return task.id; } + // multi-task version of post() + int post(std::vector & tasks, bool front = false) { + for (auto & task : tasks) { + if (task.id == -1) { + task.id = id++; + LOG_VERBOSE("new task id", {{"new_id", task.id}}); + } + if (front) { + queue_tasks.push_front(std::move(task)); + } else { + queue_tasks.push_back(std::move(task)); + } + } + condition_tasks.notify_one(); + return 0; + } + // Add a new task, but defer until one slot is available void defer(server_task task) { std::unique_lock lock(mutex_tasks); queue_tasks_deferred.push_back(std::move(task)); } - // Get the next id for creating anew task + // Get the next id for creating a new task int get_new_id() { std::unique_lock lock(mutex_tasks); int new_id = id++; @@ -427,11 +446,6 @@ struct server_queue { callback_new_task = std::move(callback); } - // Register function to process a multitask when it is finished - void on_finish_multitask(std::function callback) { - callback_finish_multitask = std::move(callback); - } - // Register the function to be called when all slots data is ready to be processed void on_update_slots(std::function callback) { callback_update_slots = std::move(callback); @@ -480,22 +494,6 @@ struct server_queue { callback_new_task(task); } - LOG_VERBOSE("update_multitasks", {}); - - // check if we have any finished multitasks - auto queue_iterator = queue_multitasks.begin(); - while (queue_iterator != queue_multitasks.end()) { - if (queue_iterator->subtasks_remaining.empty()) { - // all subtasks done == multitask is done - server_task_multi current_multitask = *queue_iterator; - callback_finish_multitask(current_multitask); - // remove this multitask - queue_iterator = queue_multitasks.erase(queue_iterator); - } else { - ++queue_iterator; - } - } - // all tasks in the current loop is processed, slots data is now ready LOG_VERBOSE("callback_update_slots", {}); @@ -516,38 +514,11 @@ struct server_queue { } } } - - // - // functions to manage multitasks - // - - // add a multitask by specifying the id of all subtask (subtask is a server_task) - void add_multitask(int id_multi, std::vector & sub_ids) { - std::lock_guard lock(mutex_tasks); - server_task_multi multi; - multi.id = id_multi; - std::copy(sub_ids.begin(), sub_ids.end(), std::inserter(multi.subtasks_remaining, multi.subtasks_remaining.end())); - queue_multitasks.push_back(multi); - } - - // updatethe remaining subtasks, while appending results to multitask - void update_multitask(int id_multi, int id_sub, server_task_result & result) { - std::lock_guard lock(mutex_tasks); - for (auto & multitask : queue_multitasks) { - if (multitask.id == id_multi) { - multitask.subtasks_remaining.erase(id_sub); - multitask.results.push_back(result); - } - } - } }; struct server_response { - typedef std::function callback_multitask_t; - callback_multitask_t callback_update_multitask; - // for keeping track of all tasks waiting for the result - std::set waiting_task_ids; + std::unordered_set waiting_task_ids; // the main result queue std::vector queue_results; @@ -563,6 +534,12 @@ struct server_response { waiting_task_ids.insert(id_task); } + void add_waiting_tasks(const std::vector & tasks) { + for (const auto & t : tasks) { + add_waiting_task_id(t.id); + } + } + // when the request is finished, we can remove task associated with it void remove_waiting_task_id(int id_task) { LOG_VERBOSE("remove waiting for task id", {{"id_task", id_task}}); @@ -571,8 +548,8 @@ struct server_response { waiting_task_ids.erase(id_task); } - // This function blocks the thread until there is a response for this id_task - server_task_result recv(int id_task) { + // This function blocks the thread until there is a response for one of the id_tasks + server_task_result recv(const std::unordered_set & id_tasks) { while (true) { std::unique_lock lock(mutex_results); condition_results.wait(lock, [&]{ @@ -580,8 +557,7 @@ struct server_response { }); for (int i = 0; i < (int) queue_results.size(); i++) { - if (queue_results[i].id == id_task) { - assert(queue_results[i].id_multi == -1); + if (id_tasks.find(queue_results[i].id) != id_tasks.end()) { server_task_result res = queue_results[i]; queue_results.erase(queue_results.begin() + i); return res; @@ -592,28 +568,21 @@ struct server_response { // should never reach here } - // Register the function to update multitask - void on_multitask_update(callback_multitask_t callback) { - callback_update_multitask = std::move(callback); + // single-task version of recv() + server_task_result recv(int id_task) { + std::unordered_set id_tasks = {id_task}; + return recv(id_tasks); } // Send a new result to a waiting id_task - void send(server_task_result result) { + void send(server_task_result & result) { LOG_VERBOSE("send new result", {{"id_task", result.id}}); std::unique_lock lock(mutex_results); for (const auto & id_task : waiting_task_ids) { - // LOG_TEE("waiting task id %i \n", id_task); - // for now, tasks that have associated parent multitasks just get erased once multitask picks up the result - if (result.id_multi == id_task) { - LOG_VERBOSE("callback_update_multitask", {{"id_task", id_task}}); - callback_update_multitask(id_task, result.id, result); - continue; - } - if (result.id == id_task) { LOG_VERBOSE("queue_results.push_back", {{"id_task", id_task}}); - queue_results.push_back(result); + queue_results.push_back(std::move(result)); condition_results.notify_all(); return; } @@ -966,7 +935,7 @@ struct server_context { slot.params.input_suffix = json_value(data, "input_suffix", default_params.input_suffix); // get prompt - if (!task.infill) { + if (task.cmpl_type != SERVER_TASK_CMPL_TYPE_INFILL) { const auto & prompt = data.find("prompt"); if (prompt == data.end()) { send_error(task, "\"prompt\" must be provided", ERROR_TYPE_INVALID_REQUEST); @@ -1359,23 +1328,21 @@ struct server_context { } void send_error(const server_task & task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) { - send_error(task.id, task.id_multi, error, type); + send_error(task.id, error, type); } void send_error(const server_slot & slot, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) { - send_error(slot.id_task, slot.id_multi, error, type); + send_error(slot.id_task, error, type); } - void send_error(const int id_task, const int id_multi, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) { + void send_error(const int id_task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) { LOG_ERROR("task error", { - {"id_multi", id_multi}, {"id_task", id_task}, {"error", error}, }); server_task_result res; res.id = id_task; - res.id_multi = id_multi; res.stop = false; res.error = true; res.data = format_error_response(error, type); @@ -1386,14 +1353,14 @@ struct server_context { void send_partial_response(server_slot & slot, completion_token_output tkn) { server_task_result res; res.id = slot.id_task; - res.id_multi = slot.id_multi; res.error = false; res.stop = false; res.data = json { {"content", tkn.text_to_send}, {"stop", false}, {"id_slot", slot.id}, - {"multimodal", false} + {"multimodal", false}, + {"index", slot.index}, }; if (slot.sparams.n_probs > 0) { @@ -1423,7 +1390,6 @@ struct server_context { void send_final_response(const server_slot & slot) { server_task_result res; res.id = slot.id_task; - res.id_multi = slot.id_multi; res.error = false; res.stop = true; res.data = json { @@ -1441,7 +1407,8 @@ struct server_context { {"stopped_limit", slot.stopped_limit}, {"stopping_word", slot.stopping_word}, {"tokens_cached", slot.n_past}, - {"timings", slot.get_formated_timings()} + {"timings", slot.get_formated_timings()}, + {"index", slot.index}, }; if (slot.sparams.n_probs > 0) { @@ -1473,7 +1440,6 @@ struct server_context { void send_embedding(const server_slot & slot, const llama_batch & batch) { server_task_result res; res.id = slot.id_task; - res.id_multi = slot.id_multi; res.error = false; res.stop = true; @@ -1508,82 +1474,127 @@ struct server_context { res.data = json { {"embedding", embd_res}, + {"index", slot.index}, }; } queue_results.send(res); } - void request_completion(int id_task, int id_multi, json data, bool infill, bool embedding) { - server_task task; - task.id = id_task; - task.id_multi = id_multi; - task.id_target = 0; - task.data = std::move(data); - task.infill = infill; - task.embedding = embedding; - task.type = SERVER_TASK_TYPE_COMPLETION; + // + // Functions to create new task(s) and receive result(s) + // - // when a completion task's prompt array is not a singleton, we split it into multiple requests - // otherwise, it's a single-prompt task, we actually queue it - // if there's numbers in the prompt array it will be treated as an array of tokens - if (task.data.count("prompt") != 0 && task.data.at("prompt").size() > 1) { - bool numbers = false; - for (const auto & e : task.data.at("prompt")) { - if (e.is_number()) { - numbers = true; + std::vector create_tasks_cmpl(json data, server_task_cmpl_type cmpl_type) { + std::vector tasks; + auto create_task = [&](json & task_data, bool replace_prompt, json prompt) { + server_task task; + task.id = queue_tasks.get_new_id(); + task.cmpl_type = cmpl_type; + task.type = SERVER_TASK_TYPE_COMPLETION; + if (replace_prompt) { + task.data = task_data; + task.data["prompt"] = prompt; + } else { + task.data = std::move(task_data); + } + tasks.push_back(std::move(task)); + }; + + static constexpr const char * error_msg = "\"prompt\" must be a string, an array of token ids or an array of prompts"; + if (!data.contains("prompt")) { + throw std::runtime_error(error_msg); + } + + json prompt = data.at("prompt"); + + // if the prompt is a singleton (i.e. a string or a list of tokens), we only need to create single task + if (prompt.is_string() || json_is_array_of_numbers(prompt)) { + data["index"] = 0; + create_task(data, false, nullptr); + } + // otherwise, it's a multiple-prompt task, we break it into smaller tasks + else if (prompt.is_array()) { + std::vector prompts = prompt; + for (size_t i = 0; i < prompts.size(); i++) { + const auto & e = prompts[i]; + if (e.is_string() || json_is_array_of_numbers(e)) { + data["index"] = i; + create_task(data, true, e); + } else { + throw std::runtime_error(error_msg); + } + } + } + // invalid case + else { + throw std::runtime_error(error_msg); + } + + return tasks; + } + + void cancel_tasks(const std::unordered_set & id_tasks) { + std::vector cancel_tasks; + cancel_tasks.reserve(id_tasks.size()); + for (const auto & id_task : id_tasks) { + LOG_VERBOSE("cancel task", {{"id_task", id_task}}); + server_task task; + task.type = SERVER_TASK_TYPE_CANCEL; + task.id_target = id_task; + cancel_tasks.push_back(task); + queue_results.remove_waiting_task_id(id_task); + } + // push to beginning of the queue, so it has highest priority + queue_tasks.post(cancel_tasks, true); + } + + // receive the results from task(s) created by create_tasks_cmpl + void receive_cmpl_results(const std::unordered_set & id_tasks, std::function&)> result_handler, std::function error_handler) { + // TODO: currently, there is no way to detect the client has cancelled the request + std::vector results(id_tasks.size()); + for (size_t i = 0; i < id_tasks.size(); i++) { + server_task_result result = queue_results.recv(id_tasks); + + if (result.error) { + error_handler(result.data); + cancel_tasks(id_tasks); + break; + } + + size_t idx = result.data["index"]; + results[idx] = result; + } + result_handler(results); + } + + // receive the results from task(s) created by create_tasks_cmpl, in stream mode + void receive_cmpl_results_stream(const std::unordered_set & id_tasks, std::function result_handler, std::function error_handler) { + size_t n_finished = 0; + while (true) { + server_task_result result = queue_results.recv(id_tasks); + if (!result_handler(result)) { + cancel_tasks(id_tasks); + break; + } + + if (result.error) { + error_handler(result.data); + cancel_tasks(id_tasks); + break; + } + + if (result.stop) { + if (++n_finished == id_tasks.size()) { break; } } - - // NOTE: split_multiprompt_task() does not handle a mix of strings and numbers, - // it will completely stall the server. I don't know where the bug for this is. - // - // if there are numbers, it needs to be treated like a single prompt, - // queue_tasks handles a mix of strings and numbers just fine. - if (numbers) { - queue_tasks.post(task); - } else { - split_multiprompt_task(id_task, task); - } - } else { - queue_tasks.post(task); } } - void request_cancel(int id_task) { - server_task task; - task.type = SERVER_TASK_TYPE_CANCEL; - task.id_target = id_task; - - queue_tasks.post(task); - } - - void split_multiprompt_task(int id_multi, const server_task & multiprompt_task) { - const int prompt_count = multiprompt_task.data.at("prompt").size(); - if (prompt_count <= 1) { - send_error(multiprompt_task, "error while handling multiple prompts"); - return; - } - - // generate all the ID for subtask - std::vector subtask_ids(prompt_count); - for (int i = 0; i < prompt_count; i++) { - subtask_ids[i] = queue_tasks.get_new_id(); - } - - // queue up the multitask so we can track its subtask progression - queue_tasks.add_multitask(id_multi, subtask_ids); - - // add subtasks - for (int i = 0; i < prompt_count; i++) { - json subtask_data = multiprompt_task.data; - subtask_data["prompt"] = subtask_data.at("prompt")[i]; - - // subtasks inherit everything else (infill mode, embedding mode, etc.) - request_completion(subtask_ids[i], id_multi, subtask_data, multiprompt_task.infill, multiprompt_task.embedding); - } - } + // + // Functions to process the task + // void process_single_task(const server_task & task) { switch (task.type) { @@ -1630,9 +1641,8 @@ struct server_context { slot->reset(); slot->id_task = task.id; - slot->id_multi = task.id_multi; - slot->infill = task.infill; - slot->embedding = task.embedding; + slot->cmpl_type = task.cmpl_type; + slot->index = json_value(task.data, "index", 0); if (!launch_slot_with_task(*slot, task)) { LOG_ERROR("error while launching slot", task.data); @@ -1699,7 +1709,6 @@ struct server_context { server_task_result res; res.id = task.id; - res.id_multi = task.id_multi; res.stop = true; res.error = false; res.data = { @@ -1861,26 +1870,6 @@ struct server_context { } } - void on_finish_multitask(const server_task_multi & multitask) { - // all subtasks done == multitask is done - server_task_result result; - result.id = multitask.id; - result.stop = true; - result.error = false; - - // collect json results into one json result - std::vector result_jsons; - for (const auto & subres : multitask.results) { - result_jsons.push_back(subres.data); - result.error = result.error && subres.error; - } - result.data = json { - { "results", result_jsons } - }; - - queue_results.send(result); - } - void update_slots() { if (system_need_update) { system_prompt_update(); @@ -2038,7 +2027,7 @@ struct server_context { slot.t_start_process_prompt = ggml_time_us(); slot.t_start_generation = 0; - if (slot.infill) { + if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_INFILL) { const bool add_bos = llama_add_bos_token(model); bool suff_rm_leading_spc = true; if (params.input_suffix.find_first_of(' ') == 0 && params.input_suffix.size() > 1) { @@ -2101,7 +2090,7 @@ struct server_context { continue; } - if (slot.embedding) { + if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING) { // this prompt is too large to process - discard it if (slot.n_prompt_tokens > n_ubatch) { slot.state = SLOT_STATE_PROCESSING; @@ -2184,7 +2173,7 @@ struct server_context { slot.n_prompt_tokens_processed = 0; } - if (slot.embedding) { + if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING) { // cannot fit the prompt in the current batch - will try next iter if (batch.n_tokens + slot.n_prompt_tokens > n_batch) { continue; @@ -2192,7 +2181,7 @@ struct server_context { } // check that we are in the right batch_type, if not defer the slot - bool slot_type = slot.embedding ? 1 : 0; + bool slot_type = slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING ? 1 : 0; if (batch_type == -1) { batch_type = slot_type; } else if (batch_type != slot_type) { @@ -2385,7 +2374,7 @@ struct server_context { } // prompt evaluated for embedding - if (slot.embedding) { + if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING) { send_embedding(slot, batch_view); slot.release(); slot.i_batch = -1; @@ -2576,6 +2565,11 @@ int main(int argc, char ** argv) { res.status = json_value(error_data, "code", 500); }; + auto res_ok = [](httplib::Response & res, json data) { + res.set_content(data.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON); + res.status = 200; + }; + svr->set_exception_handler([&res_error](const httplib::Request &, httplib::Response & res, std::exception_ptr ep) { std::string message; try { @@ -2623,7 +2617,7 @@ int main(int argc, char ** argv) { auto middleware_validate_api_key = [¶ms, &res_error](const httplib::Request & req, httplib::Response & res) { // TODO: should we apply API key to all endpoints, including "/health" and "/models"? - static const std::set protected_endpoints = { + static const std::unordered_set protected_endpoints = { "/props", "/completion", "/completions", @@ -2695,7 +2689,7 @@ int main(int argc, char ** argv) { const auto handle_health = [&](const httplib::Request &, httplib::Response & res) { // error and loading states are handled by middleware json health = {{"status", "ok"}}; - res.set_content(health.dump(), "application/json"); + res_ok(res, health); }; const auto handle_slots = [&](const httplib::Request & req, httplib::Response & res) { @@ -2707,8 +2701,6 @@ int main(int argc, char ** argv) { // request slots data using task queue server_task task; task.id = ctx_server.queue_tasks.get_new_id(); - task.id_multi = -1; - task.id_target = -1; task.type = SERVER_TASK_TYPE_METRICS; ctx_server.queue_results.add_waiting_task_id(task.id); @@ -2727,8 +2719,7 @@ int main(int argc, char ** argv) { } } - res.set_content(result.data.at("slots").dump(), MIMETYPE_JSON); - res.status = 200; // HTTP OK + res_ok(res, result.data.at("slots")); }; const auto handle_metrics = [&](const httplib::Request &, httplib::Response & res) { @@ -2740,7 +2731,6 @@ int main(int argc, char ** argv) { // request slots data using task queue server_task task; task.id = ctx_server.queue_tasks.get_new_id(); - task.id_multi = -1; task.id_target = -1; task.type = SERVER_TASK_TYPE_METRICS; task.data.push_back({{"reset_bucket", true}}); @@ -2832,7 +2822,7 @@ int main(int argc, char ** argv) { res.status = 200; // HTTP OK }; - const auto handle_slots_save = [&ctx_server, &res_error, ¶ms](const httplib::Request & req, httplib::Response & res, int id_slot) { + const auto handle_slots_save = [&ctx_server, &res_error, &res_ok, ¶ms](const httplib::Request & req, httplib::Response & res, int id_slot) { json request_data = json::parse(req.body); std::string filename = request_data.at("filename"); if (!fs_validate_filename(filename)) { @@ -2858,11 +2848,11 @@ int main(int argc, char ** argv) { if (result.error) { res_error(res, result.data); } else { - res.set_content(result.data.dump(), MIMETYPE_JSON); + res_ok(res, result.data); } }; - const auto handle_slots_restore = [&ctx_server, &res_error, ¶ms](const httplib::Request & req, httplib::Response & res, int id_slot) { + const auto handle_slots_restore = [&ctx_server, &res_error, &res_ok, ¶ms](const httplib::Request & req, httplib::Response & res, int id_slot) { json request_data = json::parse(req.body); std::string filename = request_data.at("filename"); if (!fs_validate_filename(filename)) { @@ -2888,11 +2878,11 @@ int main(int argc, char ** argv) { if (result.error) { res_error(res, result.data); } else { - res.set_content(result.data.dump(), MIMETYPE_JSON); + res_ok(res, result.data); } }; - const auto handle_slots_erase = [&ctx_server, &res_error](const httplib::Request & /* req */, httplib::Response & res, int id_slot) { + const auto handle_slots_erase = [&ctx_server, &res_error, &res_ok](const httplib::Request & /* req */, httplib::Response & res, int id_slot) { server_task task; task.type = SERVER_TASK_TYPE_SLOT_ERASE; task.data = { @@ -2908,11 +2898,16 @@ int main(int argc, char ** argv) { if (result.error) { res_error(res, result.data); } else { - res.set_content(result.data.dump(), MIMETYPE_JSON); + res_ok(res, result.data); } }; - const auto handle_slots_action = [&res_error, &handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) { + const auto handle_slots_action = [¶ms, &res_error, &handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) { + if (params.slot_save_path.empty()) { + res_error(res, format_error_response("This server does not support slots action. Start it with `--slot-save-path`", ERROR_TYPE_NOT_SUPPORTED)); + return; + } + std::string id_slot_str = req.path_params.at("id_slot"); int id_slot; @@ -2936,7 +2931,7 @@ int main(int argc, char ** argv) { } }; - const auto handle_props = [&ctx_server](const httplib::Request &, httplib::Response & res) { + const auto handle_props = [&ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) { std::string template_key = "tokenizer.chat_template", curr_tmpl; int32_t tlen = llama_model_meta_val_str(ctx_server.model, template_key.c_str(), nullptr, 0); if (tlen > 0) { @@ -2952,85 +2947,107 @@ int main(int argc, char ** argv) { { "chat_template", curr_tmpl.c_str() } }; - res.set_content(data.dump(), MIMETYPE_JSON); + res_ok(res, data); }; - const auto handle_completions = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) { + const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](server_task_cmpl_type cmpl_type, json & data, httplib::Response & res) { if (ctx_server.params.embedding) { res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED)); return; } - json data = json::parse(req.body); + std::vector tasks = ctx_server.create_tasks_cmpl(data, cmpl_type); + ctx_server.queue_results.add_waiting_tasks(tasks); + ctx_server.queue_tasks.post(tasks); - const int id_task = ctx_server.queue_tasks.get_new_id(); + bool stream = json_value(data, "stream", false); + const auto task_ids = server_task::get_list_id(tasks); - ctx_server.queue_results.add_waiting_task_id(id_task); - ctx_server.request_completion(id_task, -1, data, false, false); - - if (!json_value(data, "stream", false)) { - server_task_result result = ctx_server.queue_results.recv(id_task); - if (!result.error && result.stop) { - res.set_content(result.data.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON); - } else { - res_error(res, result.data); - } - - ctx_server.queue_results.remove_waiting_task_id(id_task); - } else { - const auto chunked_content_provider = [id_task, &ctx_server](size_t, httplib::DataSink & sink) { - while (true) { - server_task_result result = ctx_server.queue_results.recv(id_task); - if (!result.error) { - const std::string str = - "data: " + - result.data.dump(-1, ' ', false, json::error_handler_t::replace) + - "\n\n"; - - LOG_VERBOSE("data stream", { - { "to_send", str } - }); - - if (!sink.write(str.c_str(), str.size())) { - ctx_server.queue_results.remove_waiting_task_id(id_task); - return false; - } - - if (result.stop) { - break; - } - } else { - const std::string str = - "error: " + - result.data.dump(-1, ' ', false, json::error_handler_t::replace) + - "\n\n"; - - LOG_VERBOSE("data stream", { - { "to_send", str } - }); - - if (!sink.write(str.c_str(), str.size())) { - ctx_server.queue_results.remove_waiting_task_id(id_task); - return false; - } - - break; + if (!stream) { + ctx_server.receive_cmpl_results(task_ids, [&](std::vector & results) { + if (results.size() == 1) { + // single result + res_ok(res, results[0].data); + } else { + // multiple results (multitask) + json arr = json::array(); + for (const auto & res : results) { + arr.push_back(res.data); } + res_ok(res, arr); } - - ctx_server.queue_results.remove_waiting_task_id(id_task); + }, [&](json error_data) { + res_error(res, error_data); + }); + } else { + const auto chunked_content_provider = [task_ids, &ctx_server](size_t, httplib::DataSink & sink) { + ctx_server.receive_cmpl_results_stream(task_ids, [&](server_task_result result) -> bool { + return server_sent_event(sink, "data", result.data); + }, [&](json error_data) { + server_sent_event(sink, "error", error_data); + }); sink.done(); + return false; + }; + res.set_chunked_content_provider("text/event-stream", chunked_content_provider); + } + }; + const auto handle_completions = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) { + json data = json::parse(req.body); + return handle_completions_generic(SERVER_TASK_CMPL_TYPE_NORMAL, data, res); + }; + + const auto handle_infill = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) { + json data = json::parse(req.body); + return handle_completions_generic(SERVER_TASK_CMPL_TYPE_INFILL, data, res); + }; + + // TODO: maybe merge this function with "handle_completions_generic" + const auto handle_chat_completions = [&ctx_server, ¶ms, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) { + if (ctx_server.params.embedding) { + res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED)); + return; + } + + json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template); + + std::vector tasks = ctx_server.create_tasks_cmpl(data, SERVER_TASK_CMPL_TYPE_NORMAL); + ctx_server.queue_results.add_waiting_tasks(tasks); + ctx_server.queue_tasks.post(tasks); + + bool stream = json_value(data, "stream", false); + const auto task_ids = server_task::get_list_id(tasks); + const auto completion_id = gen_chatcmplid(); + + if (!stream) { + ctx_server.receive_cmpl_results(task_ids, [&](std::vector & results) { + // multitask is never support in chat completion, there is only one result + json result_oai = format_final_response_oaicompat(data, results[0].data, completion_id); + res_ok(res, result_oai); + }, [&](json error_data) { + res_error(res, error_data); + }); + } else { + const auto chunked_content_provider = [task_ids, &ctx_server, completion_id](size_t, httplib::DataSink & sink) { + ctx_server.receive_cmpl_results_stream(task_ids, [&](server_task_result result) -> bool { + std::vector result_array = format_partial_response_oaicompat(result.data, completion_id); + for (auto & event_data : result_array) { + if (event_data.empty()) { + continue; // skip the stop token + } + if (!server_sent_event(sink, "data", event_data)) { + return false; // connection is closed + } + } + return true; // ok + }, [&](json error_data) { + server_sent_event(sink, "error", error_data); + }); + sink.done(); return true; }; - - auto on_complete = [id_task, &ctx_server] (bool) { - // cancel - ctx_server.request_cancel(id_task); - ctx_server.queue_results.remove_waiting_task_id(id_task); - }; - - res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete); + res.set_chunked_content_provider("text/event-stream", chunked_content_provider); } }; @@ -3051,145 +3068,7 @@ int main(int argc, char ** argv) { res.set_content(models.dump(), MIMETYPE_JSON); }; - const auto handle_chat_completions = [&ctx_server, ¶ms, &res_error](const httplib::Request & req, httplib::Response & res) { - if (ctx_server.params.embedding) { - res_error(res, format_error_response("This server does not support chat completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED)); - return; - } - json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template); - - const int id_task = ctx_server.queue_tasks.get_new_id(); - - ctx_server.queue_results.add_waiting_task_id(id_task); - ctx_server.request_completion(id_task, -1, data, false, false); - - const auto completion_id = gen_chatcmplid(); - if (!json_value(data, "stream", false)) { - server_task_result result = ctx_server.queue_results.recv(id_task); - - if (!result.error && result.stop) { - json result_oai = format_final_response_oaicompat(data, result.data, completion_id); - - res.set_content(result_oai.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON); - } else { - res_error(res, result.data); - } - ctx_server.queue_results.remove_waiting_task_id(id_task); - } else { - const auto chunked_content_provider = [id_task, &ctx_server, completion_id](size_t, httplib::DataSink & sink) { - while (true) { - server_task_result result = ctx_server.queue_results.recv(id_task); - if (!result.error) { - std::vector result_array = format_partial_response_oaicompat(result.data, completion_id); - - for (auto it = result_array.begin(); it != result_array.end(); ++it) { - if (!it->empty()) { - const std::string str = - "data: " + - it->dump(-1, ' ', false, json::error_handler_t::replace) + - "\n\n"; - LOG_VERBOSE("data stream", {{"to_send", str}}); - if (!sink.write(str.c_str(), str.size())) { - ctx_server.queue_results.remove_waiting_task_id(id_task); - return false; - } - } - } - if (result.stop) { - break; - } - } else { - const std::string str = - "error: " + - result.data.dump(-1, ' ', false, json::error_handler_t::replace) + - "\n\n"; - LOG_VERBOSE("data stream", {{"to_send", str}}); - if (!sink.write(str.c_str(), str.size())) { - ctx_server.queue_results.remove_waiting_task_id(id_task); - return false; - } - break; - } - } - sink.done(); - ctx_server.queue_results.remove_waiting_task_id(id_task); - return true; - }; - - auto on_complete = [id_task, &ctx_server](bool) { - // cancel request - ctx_server.request_cancel(id_task); - ctx_server.queue_results.remove_waiting_task_id(id_task); - }; - - res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete); - } - }; - - const auto handle_infill = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) { - if (ctx_server.params.embedding) { - res_error(res, format_error_response("This server does not support infill. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED)); - return; - } - - json data = json::parse(req.body); - - const int id_task = ctx_server.queue_tasks.get_new_id(); - - ctx_server.queue_results.add_waiting_task_id(id_task); - ctx_server.request_completion(id_task, -1, data, true, false); - - if (!json_value(data, "stream", false)) { - server_task_result result = ctx_server.queue_results.recv(id_task); - if (!result.error && result.stop) { - res.set_content(result.data.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON); - } else { - res_error(res, result.data); - } - - ctx_server.queue_results.remove_waiting_task_id(id_task); - } else { - const auto chunked_content_provider = [id_task, &ctx_server](size_t, httplib::DataSink & sink) { - while (true) { - server_task_result result = ctx_server.queue_results.recv(id_task); - if (!result.error) { - const std::string str = - "data: " + - result.data.dump(-1, ' ', false, json::error_handler_t::replace) + - "\n\n"; - - LOG_VERBOSE("data stream", { - { "to_send", str } - }); - - if (!sink.write(str.c_str(), str.size())) { - ctx_server.queue_results.remove_waiting_task_id(id_task); - return false; - } - - if (result.stop) { - break; - } - } else { - break; - } - } - - ctx_server.queue_results.remove_waiting_task_id(id_task); - sink.done(); - - return true; - }; - - auto on_complete = [id_task, &ctx_server] (bool) { - ctx_server.request_cancel(id_task); - }; - - res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete); - } - }; - - const auto handle_tokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) { + const auto handle_tokenize = [&ctx_server, &res_ok](const httplib::Request & req, httplib::Response & res) { const json body = json::parse(req.body); std::vector tokens; @@ -3198,10 +3077,10 @@ int main(int argc, char ** argv) { tokens = ctx_server.tokenize(body.at("content"), add_special); } const json data = format_tokenizer_response(tokens); - return res.set_content(data.dump(), MIMETYPE_JSON); + res_ok(res, data); }; - const auto handle_detokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) { + const auto handle_detokenize = [&ctx_server, &res_ok](const httplib::Request & req, httplib::Response & res) { const json body = json::parse(req.body); std::string content; @@ -3211,10 +3090,10 @@ int main(int argc, char ** argv) { } const json data = format_detokenized_response(content); - return res.set_content(data.dump(), MIMETYPE_JSON); + res_ok(res, data); }; - const auto handle_embeddings = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) { + const auto handle_embeddings = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) { const json body = json::parse(req.body); bool is_openai = false; @@ -3232,35 +3111,35 @@ int main(int argc, char ** argv) { } // create and queue the task - json responses; + json responses = json::array(); + bool error = false; { - const int id_task = ctx_server.queue_tasks.get_new_id(); - ctx_server.queue_results.add_waiting_task_id(id_task); - ctx_server.request_completion(id_task, -1, {{"prompt", prompt}}, false, true); + std::vector tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_EMBEDDING); + ctx_server.queue_results.add_waiting_tasks(tasks); + ctx_server.queue_tasks.post(tasks); // get the result - server_task_result result = ctx_server.queue_results.recv(id_task); - ctx_server.queue_results.remove_waiting_task_id(id_task); - if (!result.error) { - if (result.data.count("results")) { - // result for multi-task - responses = result.data.at("results"); - } else { - // result for single task - responses = std::vector{result.data}; + std::unordered_set task_ids = server_task::get_list_id(tasks); + + ctx_server.receive_cmpl_results(task_ids, [&](std::vector & results) { + for (const auto & res : results) { + responses.push_back(res.data); } - } else { - // error received, ignore everything else - res_error(res, result.data); - return; - } + }, [&](json error_data) { + res_error(res, error_data); + error = true; + }); + } + + if (error) { + return; } // write JSON response json root = is_openai ? format_embeddings_response_oaicompat(body, responses) : responses[0]; - return res.set_content(root.dump(), MIMETYPE_JSON); + res_ok(res, root); }; const auto handle_lora_adapters_list = [&](const httplib::Request &, httplib::Response & res) { @@ -3273,7 +3152,7 @@ int main(int argc, char ** argv) { {"scale", la.scale}, }); } - res.set_content(result.dump(), MIMETYPE_JSON); + res_ok(res, result); res.status = 200; // HTTP OK }; @@ -3305,7 +3184,7 @@ int main(int argc, char ** argv) { server_task_result result = ctx_server.queue_results.recv(id_task); ctx_server.queue_results.remove_waiting_task_id(id_task); - res.set_content(result.data.dump(), MIMETYPE_JSON); + res_ok(res, result.data); res.status = 200; // HTTP OK }; @@ -3366,10 +3245,7 @@ int main(int argc, char ** argv) { svr->Post("/lora-adapters", handle_lora_adapters_apply); // Save & load slots svr->Get ("/slots", handle_slots); - if (!params.slot_save_path.empty()) { - // only enable slot endpoints if slot_save_path is set - svr->Post("/slots/:id_slot", handle_slots_action); - } + svr->Post("/slots/:id_slot", handle_slots_action); // // Start the server @@ -3433,17 +3309,8 @@ int main(int argc, char ** argv) { ctx_server.queue_tasks.on_new_task(std::bind( &server_context::process_single_task, &ctx_server, std::placeholders::_1)); - ctx_server.queue_tasks.on_finish_multitask(std::bind( - &server_context::on_finish_multitask, &ctx_server, std::placeholders::_1)); ctx_server.queue_tasks.on_update_slots(std::bind( &server_context::update_slots, &ctx_server)); - ctx_server.queue_results.on_multitask_update(std::bind( - &server_queue::update_multitask, - &ctx_server.queue_tasks, - std::placeholders::_1, - std::placeholders::_2, - std::placeholders::_3 - )); shutdown_handler = [&](int) { ctx_server.queue_tasks.terminate(); diff --git a/examples/server/tests/features/parallel.feature b/examples/server/tests/features/parallel.feature index 6cd306a2b..96a96d6f8 100644 --- a/examples/server/tests/features/parallel.feature +++ b/examples/server/tests/features/parallel.feature @@ -52,8 +52,8 @@ Feature: Parallel Then all prompts are predicted with tokens Examples: | streaming | n_predict | - | disabled | 128 | - | enabled | 64 | + | disabled | 200 | + | enabled | 200 | Scenario Outline: Multi users OAI completions compatibility no v1 Given a system prompt You are a writer. diff --git a/examples/server/tests/features/steps/steps.py b/examples/server/tests/features/steps/steps.py index 1ba7b60b6..1864a694f 100644 --- a/examples/server/tests/features/steps/steps.py +++ b/examples/server/tests/features/steps/steps.py @@ -818,7 +818,7 @@ async def concurrent_requests(context, f_completion, *args, **kwargs): for prompt_no in range(context.n_prompts): shifted_args = [context.prompts.pop(), seeds[prompt_no], *args] context.concurrent_tasks.append(asyncio.create_task(f_completion(*shifted_args, **kwargs))) - await asyncio.sleep(0.1) + await asyncio.sleep(0.01) @step('the slot {slot_id:d} is saved with filename "{filename}"') diff --git a/examples/server/tests/features/wrong_usages.feature b/examples/server/tests/features/wrong_usages.feature index cf14b3b44..61d5f315e 100644 --- a/examples/server/tests/features/wrong_usages.feature +++ b/examples/server/tests/features/wrong_usages.feature @@ -8,9 +8,12 @@ Feature: Wrong usage of llama.cpp server Scenario: Infinite loop Given a server listening on localhost:8080 And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models + And 42 as server seed + And 2048 KV cache size # Uncomment below to fix the issue #And 64 server max tokens to predict Then the server is starting + Then the server is healthy Given a prompt: """ Go to: infinite loop diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index e6a1f0697..edfce65b6 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -3,6 +3,14 @@ #include "llama.h" #include "common.h" +#ifndef NDEBUG +// crash the server in debug mode, otherwise send an http 500 error +#define CPPHTTPLIB_NO_EXCEPTIONS 1 +#endif +// increase max payload length to allow use of larger context size +#define CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 1048576 +#include "httplib.h" + // Change JSON_ASSERT from assert() to GGML_ASSERT: #define JSON_ASSERT GGML_ASSERT #include "json.hpp" @@ -279,6 +287,18 @@ static size_t find_partial_stop_string(const std::string &stop, const std::strin return std::string::npos; } +static bool json_is_array_of_numbers(json data) { + if (data.is_array()) { + for (const auto & e : data) { + if (!e.is_number()) { + return false; + } + } + return true; + } + return false; +} + // TODO: reuse llama_detokenize template static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) { @@ -343,6 +363,19 @@ static json probs_vector_to_json(const llama_context * ctx, const std::vector Date: Mon, 2 Sep 2024 08:25:30 -0700 Subject: [PATCH 10/22] ggml : add pthread includes on FreeBSD (#9258) --- ggml/src/ggml.c | 3 +++ 1 file changed, 3 insertions(+) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 83140d757..6e2ebf283 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -147,6 +147,9 @@ static int sched_yield (void) { #include #include #include +#if defined(__FreeBSD__) +#include +#endif typedef void * thread_ret_t; From 048de848ee4baad4531fcd6239438f5d55be365c Mon Sep 17 00:00:00 2001 From: slaren Date: Mon, 2 Sep 2024 18:11:13 +0200 Subject: [PATCH 11/22] docker : fix missing binaries in full-cuda image (#9278) --- .devops/full-cuda.Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.devops/full-cuda.Dockerfile b/.devops/full-cuda.Dockerfile index b8a354246..d5acd35e2 100644 --- a/.devops/full-cuda.Dockerfile +++ b/.devops/full-cuda.Dockerfile @@ -27,7 +27,7 @@ RUN if [ "${CUDA_DOCKER_ARCH}" != "default" ]; then \ export CMAKE_ARGS="-DCMAKE_CUDA_ARCHITECTURES=${CUDA_DOCKER_ARCH}"; \ fi && \ cmake -B build -DGGML_CUDA=ON -DLLAMA_CURL=ON ${CMAKE_ARGS} -DCMAKE_EXE_LINKER_FLAGS=-Wl,--allow-shlib-undefined . && \ - cmake --build build --config Release --target llama-cli -j$(nproc) && \ + cmake --build build --config Release -j$(nproc) && \ cp build/bin/* . ENTRYPOINT ["/app/.devops/tools.sh"] From f1485161e58d562099dd050f8ac3a9ea9f4cd765 Mon Sep 17 00:00:00 2001 From: Zhenwei Jin <109658203+kylo5aby@users.noreply.github.com> Date: Tue, 3 Sep 2024 01:53:23 +0800 Subject: [PATCH 12/22] src: make tail invalid when kv cell is intersection for mamba (#9249) --- src/llama.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/llama.cpp b/src/llama.cpp index 4e203471c..883559716 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -3810,7 +3810,8 @@ static bool llama_kv_cache_seq_rm( if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) { return false; } - if (p0 <= cell.pos && p1 < cell.pos) { + // invalidate tails which will be cleared + if (p0 <= cell.pos && cell.pos < p1) { tail_id = -1; } } From 48baa61eccdca9205daf8d620ba28055c2347b64 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Mon, 2 Sep 2024 22:08:38 +0200 Subject: [PATCH 13/22] server : test script : add timeout for all requests (#9282) --- .../server/tests/features/parallel.feature | 4 +-- examples/server/tests/features/steps/steps.py | 36 +++++++++---------- 2 files changed, 20 insertions(+), 20 deletions(-) diff --git a/examples/server/tests/features/parallel.feature b/examples/server/tests/features/parallel.feature index 96a96d6f8..6cd306a2b 100644 --- a/examples/server/tests/features/parallel.feature +++ b/examples/server/tests/features/parallel.feature @@ -52,8 +52,8 @@ Feature: Parallel Then all prompts are predicted with tokens Examples: | streaming | n_predict | - | disabled | 200 | - | enabled | 200 | + | disabled | 128 | + | enabled | 64 | Scenario Outline: Multi users OAI completions compatibility no v1 Given a system prompt You are a writer. diff --git a/examples/server/tests/features/steps/steps.py b/examples/server/tests/features/steps/steps.py index 1864a694f..18daad476 100644 --- a/examples/server/tests/features/steps/steps.py +++ b/examples/server/tests/features/steps/steps.py @@ -23,6 +23,8 @@ from prometheus_client import parser # pyright: reportRedeclaration=false +DEFAULT_TIMEOUT_SECONDS = aiohttp.ClientTimeout(total=600) + @step("a server listening on {server_fqdn}:{server_port}") def step_server_config(context, server_fqdn: str, server_port: str): context.server_fqdn = server_fqdn @@ -689,7 +691,7 @@ def step_tokenize_set_add_special(context): @async_run_until_complete async def step_tokenize(context): context.tokenized_text = context_text(context) - async with aiohttp.ClientSession() as session: + async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session: tokenize_args = { "content": context.tokenized_text, } @@ -706,7 +708,7 @@ async def step_tokenize(context): @async_run_until_complete async def step_detokenize(context): assert len(context.tokens) > 0 - async with aiohttp.ClientSession() as session: + async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session: async with session.post(f'{context.base_url}/detokenize', json={ "tokens": context.tokens, @@ -735,7 +737,7 @@ def step_strings_for_tokenization(context): @step('an OPTIONS request is sent from {origin}') @async_run_until_complete async def step_options_request(context, origin): - async with aiohttp.ClientSession() as session: + async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session: headers = {'Authorization': f'Bearer {context.user_api_key}', 'Origin': origin} async with session.options(f'{context.base_url}/v1/chat/completions', headers=headers) as response: @@ -751,7 +753,7 @@ def step_check_options_header_value(context, cors_header, cors_header_value): @step('prometheus metrics are exposed') @async_run_until_complete async def step_prometheus_metrics_exported(context): - async with aiohttp.ClientSession() as session: + async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session: async with await session.get(f'{context.base_url}/metrics') as metrics_response: assert metrics_response.status == 200 assert metrics_response.headers['Content-Type'] == "text/plain; version=0.0.4" @@ -824,7 +826,7 @@ async def concurrent_requests(context, f_completion, *args, **kwargs): @step('the slot {slot_id:d} is saved with filename "{filename}"') @async_run_until_complete async def step_save_slot(context, slot_id, filename): - async with aiohttp.ClientSession() as session: + async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session: async with session.post(f'{context.base_url}/slots/{slot_id}?action=save', json={"filename": filename}, headers={"Content-Type": "application/json"}) as response: @@ -834,7 +836,7 @@ async def step_save_slot(context, slot_id, filename): @step('the slot {slot_id:d} is restored with filename "{filename}"') @async_run_until_complete async def step_restore_slot(context, slot_id, filename): - async with aiohttp.ClientSession() as session: + async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session: async with session.post(f'{context.base_url}/slots/{slot_id}?action=restore', json={"filename": filename}, headers={"Content-Type": "application/json"}) as response: @@ -844,7 +846,7 @@ async def step_restore_slot(context, slot_id, filename): @step('the slot {slot_id:d} is erased') @async_run_until_complete async def step_erase_slot(context, slot_id): - async with aiohttp.ClientSession() as session: + async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session: async with session.post(f'{context.base_url}/slots/{slot_id}?action=erase', headers={"Content-Type": "application/json"}) as response: context.response = response @@ -853,7 +855,7 @@ async def step_erase_slot(context, slot_id): @step('switch {on_or_off} lora adapter {lora_id:d}') @async_run_until_complete async def toggle_lora_adapter(context, on_or_off: str, lora_id: int): - async with aiohttp.ClientSession() as session: + async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session: async with session.post(f'{context.base_url}/lora-adapters', json=[{'id': lora_id, 'scale': 1 if on_or_off == 'on' else 0}], headers={"Content-Type": "application/json"}) as response: @@ -889,7 +891,7 @@ async def request_completion(prompt, print(f"Set user_api_key: {user_api_key}") headers['Authorization'] = f'Bearer {user_api_key}' - async with aiohttp.ClientSession() as session: + async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session: async with session.post(f'{base_url}/completion', json={ "input_prefix": prompt_prefix, @@ -902,8 +904,7 @@ async def request_completion(prompt, "temperature": temperature if temperature is not None else 0.8, "n_probs": 2, }, - headers=headers, - timeout=3600) as response: + headers=headers) as response: if expect_api_error is None or not expect_api_error: assert response.status == 200 assert response.headers['Access-Control-Allow-Origin'] == origin @@ -961,7 +962,7 @@ async def oai_chat_completions(user_prompt, if async_client: origin = 'llama.cpp' headers = {'Authorization': f'Bearer {user_api_key}', 'Origin': origin} - async with aiohttp.ClientSession() as session: + async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session: async with session.post(f'{base_url}{base_path}', json=payload, headers=headers) as response: @@ -1048,7 +1049,7 @@ async def oai_chat_completions(user_prompt, async def request_embedding(content, seed, base_url=None) -> list[list[float]]: - async with aiohttp.ClientSession() as session: + async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session: async with session.post(f'{base_url}/embedding', json={ "content": content, @@ -1068,14 +1069,13 @@ async def request_oai_embeddings(input, seed, headers=[] if user_api_key is not None: headers = {'Authorization': f'Bearer {user_api_key}', 'Origin': origin} - async with aiohttp.ClientSession() as session: + async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session: async with session.post(f'{base_url}/v1/embeddings', json={ "input": input, "model": model, }, - headers=headers, - timeout=3600) as response: + headers=headers) as response: assert response.status == 200, f"received status code not expected: {response.status}" assert response.headers['Access-Control-Allow-Origin'] == origin assert response.headers['Content-Type'] == "application/json; charset=utf-8" @@ -1194,7 +1194,7 @@ async def wait_for_slots_status(context, if 'GITHUB_ACTIONS' in os.environ: timeout *= 2 - async with aiohttp.ClientSession() as session: + async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session: while True: async with await session.get(f'{base_url}/slots', params=params) as slots_response: status_code = slots_response.status @@ -1237,7 +1237,7 @@ def assert_embeddings(embeddings): async def request_slots_status(context, expected_slots): - async with aiohttp.ClientSession() as session: + async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session: async with await session.get(f'{context.base_url}/slots') as slots_response: assert slots_response.status == 200 slots = await slots_response.json() From b69a480af4db8ffd6cbb2efe7ed8132bc7d39073 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 3 Sep 2024 10:00:36 +0300 Subject: [PATCH 14/22] readme : refactor API section + remove old hot topics --- README.md | 24 +++--------------------- 1 file changed, 3 insertions(+), 21 deletions(-) diff --git a/README.md b/README.md index bb2b93a35..e30ab0c8c 100644 --- a/README.md +++ b/README.md @@ -10,32 +10,14 @@ Inference of Meta's [LLaMA](https://arxiv.org/abs/2302.13971) model (and others) in pure C/C++ -> [!IMPORTANT] -[2024 Jun 12] Binaries have been renamed w/ a `llama-` prefix. `main` is now `llama-cli`, `server` is `llama-server`, etc (https://github.com/ggerganov/llama.cpp/pull/7809) - ## Recent API changes -- [2024 Jun 26] The source code and CMake build scripts have been restructured https://github.com/ggerganov/llama.cpp/pull/8006 -- [2024 Apr 21] `llama_token_to_piece` can now optionally render special tokens https://github.com/ggerganov/llama.cpp/pull/6807 -- [2024 Apr 4] State and session file functions reorganized under `llama_state_*` https://github.com/ggerganov/llama.cpp/pull/6341 -- [2024 Mar 26] Logits and embeddings API updated for compactness https://github.com/ggerganov/llama.cpp/pull/6122 -- [2024 Mar 13] Add `llama_synchronize()` + `llama_context_params.n_ubatch` https://github.com/ggerganov/llama.cpp/pull/6017 -- [2024 Mar 8] `llama_kv_cache_seq_rm()` returns a `bool` instead of `void`, and new `llama_n_seq_max()` returns the upper limit of acceptable `seq_id` in batches (relevant when dealing with multiple sequences) https://github.com/ggerganov/llama.cpp/pull/5328 -- [2024 Mar 4] Embeddings API updated https://github.com/ggerganov/llama.cpp/pull/5796 -- [2024 Mar 3] `struct llama_context_params` https://github.com/ggerganov/llama.cpp/pull/5849 +- [Changelog for `libllama` API](https://github.com/ggerganov/llama.cpp/issues/9289) +- [Changelog for `llama-server` REST API](https://github.com/ggerganov/llama.cpp/issues/9291) ## Hot topics -- **`convert.py` has been deprecated and moved to `examples/convert_legacy_llama.py`, please use `convert_hf_to_gguf.py`** https://github.com/ggerganov/llama.cpp/pull/7430 -- Initial Flash-Attention support: https://github.com/ggerganov/llama.cpp/pull/5021 -- BPE pre-tokenization support has been added: https://github.com/ggerganov/llama.cpp/pull/6920 -- MoE memory layout has been updated - reconvert models for `mmap` support and regenerate `imatrix` https://github.com/ggerganov/llama.cpp/pull/6387 -- Model sharding instructions using `gguf-split` https://github.com/ggerganov/llama.cpp/discussions/6404 -- Fix major bug in Metal batched inference https://github.com/ggerganov/llama.cpp/pull/6225 -- Multi-GPU pipeline parallelism support https://github.com/ggerganov/llama.cpp/pull/6017 -- Looking for contributions to add Deepseek support: https://github.com/ggerganov/llama.cpp/issues/5981 -- Quantization blind testing: https://github.com/ggerganov/llama.cpp/discussions/5962 -- Initial Mamba support has been added: https://github.com/ggerganov/llama.cpp/pull/5328 +- *add hot topics here* ---- From 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Tue, 3 Sep 2024 20:58:54 +0300 Subject: [PATCH 15/22] llama-bench : add JSONL (NDJSON) output mode (#9288) * llama-bench : add JSONL (NDJSON) output mode * llama-bench : update usage docs --- examples/llama-bench/README.md | 65 +++++++----- examples/llama-bench/llama-bench.cpp | 143 ++++++++++++++++----------- 2 files changed, 127 insertions(+), 81 deletions(-) diff --git a/examples/llama-bench/README.md b/examples/llama-bench/README.md index 52b0e74d3..6bbe4bb75 100644 --- a/examples/llama-bench/README.md +++ b/examples/llama-bench/README.md @@ -14,7 +14,8 @@ Performance testing tool for llama.cpp. 1. [Markdown](#markdown) 2. [CSV](#csv) 3. [JSON](#json) - 4. [SQL](#sql) + 4. [JSONL](#jsonl) + 5. [SQL](#sql) ## Syntax @@ -23,27 +24,34 @@ usage: ./llama-bench [options] options: -h, --help - -m, --model (default: models/7B/ggml-model-q4_0.gguf) - -p, --n-prompt (default: 512) - -n, --n-gen (default: 128) - -pg (default: 512,128) - -b, --batch-size (default: 2048) - -ub, --ubatch-size (default: 512) - -ctk, --cache-type-k (default: f16) - -ctv, --cache-type-v (default: f16) - -t, --threads (default: 16) - -ngl, --n-gpu-layers (default: 99) - -sm, --split-mode (default: layer) - -mg, --main-gpu (default: 0) - -nkvo, --no-kv-offload <0|1> (default: 0) - -fa, --flash-attn <0|1> (default: 0) - -mmp, --mmap <0|1> (default: 1) - --numa (default: disabled) - -embd, --embeddings <0|1> (default: 0) - -ts, --tensor-split (default: 0) - -r, --repetitions (default: 5) - -o, --output (default: md) - -v, --verbose (default: 0) + -m, --model (default: models/7B/ggml-model-q4_0.gguf) + -p, --n-prompt (default: 512) + -n, --n-gen (default: 128) + -pg (default: ) + -b, --batch-size (default: 2048) + -ub, --ubatch-size (default: 512) + -ctk, --cache-type-k (default: f16) + -ctv, --cache-type-v (default: f16) + -t, --threads (default: 8) + -C, --cpu-mask (default: 0x0) + --cpu-strict <0|1> (default: 0) + --poll <0...100> (default: 50) + -ngl, --n-gpu-layers (default: 99) + -rpc, --rpc (default: ) + -sm, --split-mode (default: layer) + -mg, --main-gpu (default: 0) + -nkvo, --no-kv-offload <0|1> (default: 0) + -fa, --flash-attn <0|1> (default: 0) + -mmp, --mmap <0|1> (default: 1) + --numa (default: disabled) + -embd, --embeddings <0|1> (default: 0) + -ts, --tensor-split (default: 0) + -r, --repetitions (default: 5) + --prio <0|1|2|3> (default: 0) + --delay <0...N> (seconds) (default: 0) + -o, --output (default: md) + -oe, --output-err (default: none) + -v, --verbose (default: 0) Multiple values can be given for each parameter by separating them with ',' or by specifying the parameter multiple times. ``` @@ -238,6 +246,19 @@ $ ./llama-bench -o json ] ``` + +### JSONL + +```sh +$ ./llama-bench -o jsonl +``` + +```json lines +{"build_commit":"3469684","build_number":1275,"cuda":true,"metal":false,"gpu_blas":true,"blas":true,"cpu_info":"13th Gen Intel(R) Core(TM) i9-13900K","gpu_info":"NVIDIA GeForce RTX 3090 Ti","model_filename":"models/7B/ggml-model-q4_0.gguf","model_type":"llama 7B mostly Q4_0","model_size":3825065984,"model_n_params":6738415616,"n_batch":512,"n_threads":16,"f16_kv":true,"n_gpu_layers":99,"main_gpu":0,"mul_mat_q":true,"tensor_split":"0.00","n_prompt":512,"n_gen":0,"test_time":"2023-09-23T12:09:57Z","avg_ns":212365953,"stddev_ns":985423,"avg_ts":2410.974041,"stddev_ts":11.163766,"samples_ns":[213837238,211635853,212328053,211329715,212698907],"samples_ts":[2394.34,2419.25,2411.36,2422.75,2407.16]} +{"build_commit":"3469684","build_number":1275,"cuda":true,"metal":false,"gpu_blas":true,"blas":true,"cpu_info":"13th Gen Intel(R) Core(TM) i9-13900K","gpu_info":"NVIDIA GeForce RTX 3090 Ti","model_filename":"models/7B/ggml-model-q4_0.gguf","model_type":"llama 7B mostly Q4_0","model_size":3825065984,"model_n_params":6738415616,"n_batch":512,"n_threads":16,"f16_kv":true,"n_gpu_layers":99,"main_gpu":0,"mul_mat_q":true,"tensor_split":"0.00","n_prompt":0,"n_gen":128,"test_time":"2023-09-23T12:09:59Z","avg_ns":977425219,"stddev_ns":9268593,"avg_ts":130.965708,"stddev_ts":1.238924,"samples_ns":[984472709,974901233,989474741,970729355,967548060],"samples_ts":[130.019,131.295,129.362,131.86,132.293]} +``` + + ### SQL SQL output is suitable for importing into a SQLite database. The output can be piped into the `sqlite3` command line tool to add the results to a database. diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index 8edadef90..b65662ecd 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -171,13 +171,14 @@ static std::string get_gpu_info() { } // command line params -enum output_formats {NONE, CSV, JSON, MARKDOWN, SQL}; +enum output_formats {NONE, CSV, JSON, JSONL, MARKDOWN, SQL}; static const char * output_format_str(output_formats format) { switch (format) { case NONE: return "none"; case CSV: return "csv"; case JSON: return "json"; + case JSONL: return "jsonl"; case MARKDOWN: return "md"; case SQL: return "sql"; default: GGML_ABORT("invalid output format"); @@ -191,6 +192,8 @@ static bool output_format_from_str(const std::string & s, output_formats & forma format = CSV; } else if (s == "json") { format = JSON; + } else if (s == "jsonl") { + format = JSONL; } else if (s == "md") { format = MARKDOWN; } else if (s == "sql") { @@ -283,34 +286,34 @@ static void print_usage(int /* argc */, char ** argv) { printf("\n"); printf("options:\n"); printf(" -h, --help\n"); - printf(" -m, --model (default: %s)\n", join(cmd_params_defaults.model, ",").c_str()); - printf(" -p, --n-prompt (default: %s)\n", join(cmd_params_defaults.n_prompt, ",").c_str()); - printf(" -n, --n-gen (default: %s)\n", join(cmd_params_defaults.n_gen, ",").c_str()); - printf(" -pg (default: %s)\n", join(transform_to_str(cmd_params_defaults.n_pg, pair_str), ",").c_str()); - printf(" -b, --batch-size (default: %s)\n", join(cmd_params_defaults.n_batch, ",").c_str()); - printf(" -ub, --ubatch-size (default: %s)\n", join(cmd_params_defaults.n_ubatch, ",").c_str()); - printf(" -ctk, --cache-type-k (default: %s)\n", join(transform_to_str(cmd_params_defaults.type_k, ggml_type_name), ",").c_str()); - printf(" -ctv, --cache-type-v (default: %s)\n", join(transform_to_str(cmd_params_defaults.type_v, ggml_type_name), ",").c_str()); - printf(" -t, --threads (default: %s)\n", join(cmd_params_defaults.n_threads, ",").c_str()); - printf(" -C, --cpu-mask (default: %s)\n", join(cmd_params_defaults.cpu_mask, ",").c_str()); - printf(" --cpu-strict <0|1> (default: %s)\n", join(cmd_params_defaults.cpu_strict, ",").c_str()); - printf(" --poll <0...100> (default: %s)\n", join(cmd_params_defaults.poll, ",").c_str()); - printf(" -ngl, --n-gpu-layers (default: %s)\n", join(cmd_params_defaults.n_gpu_layers, ",").c_str()); - printf(" -rpc, --rpc (default: %s)\n", join(cmd_params_defaults.rpc_servers, ",").c_str()); - printf(" -sm, --split-mode (default: %s)\n", join(transform_to_str(cmd_params_defaults.split_mode, split_mode_str), ",").c_str()); - printf(" -mg, --main-gpu (default: %s)\n", join(cmd_params_defaults.main_gpu, ",").c_str()); - printf(" -nkvo, --no-kv-offload <0|1> (default: %s)\n", join(cmd_params_defaults.no_kv_offload, ",").c_str()); - printf(" -fa, --flash-attn <0|1> (default: %s)\n", join(cmd_params_defaults.flash_attn, ",").c_str()); - printf(" -mmp, --mmap <0|1> (default: %s)\n", join(cmd_params_defaults.use_mmap, ",").c_str()); - printf(" --numa (default: disabled)\n"); - printf(" -embd, --embeddings <0|1> (default: %s)\n", join(cmd_params_defaults.embeddings, ",").c_str()); - printf(" -ts, --tensor-split (default: 0)\n"); - printf(" -r, --repetitions (default: %d)\n", cmd_params_defaults.reps); - printf(" --prio <0|1|2|3> (default: %d)\n", cmd_params_defaults.prio); - printf(" --delay <0...N> (seconds) (default: %d)\n", cmd_params_defaults.delay); - printf(" -o, --output (default: %s)\n", output_format_str(cmd_params_defaults.output_format)); - printf(" -oe, --output-err (default: %s)\n", output_format_str(cmd_params_defaults.output_format_stderr)); - printf(" -v, --verbose (default: %s)\n", cmd_params_defaults.verbose ? "1" : "0"); + printf(" -m, --model (default: %s)\n", join(cmd_params_defaults.model, ",").c_str()); + printf(" -p, --n-prompt (default: %s)\n", join(cmd_params_defaults.n_prompt, ",").c_str()); + printf(" -n, --n-gen (default: %s)\n", join(cmd_params_defaults.n_gen, ",").c_str()); + printf(" -pg (default: %s)\n", join(transform_to_str(cmd_params_defaults.n_pg, pair_str), ",").c_str()); + printf(" -b, --batch-size (default: %s)\n", join(cmd_params_defaults.n_batch, ",").c_str()); + printf(" -ub, --ubatch-size (default: %s)\n", join(cmd_params_defaults.n_ubatch, ",").c_str()); + printf(" -ctk, --cache-type-k (default: %s)\n", join(transform_to_str(cmd_params_defaults.type_k, ggml_type_name), ",").c_str()); + printf(" -ctv, --cache-type-v (default: %s)\n", join(transform_to_str(cmd_params_defaults.type_v, ggml_type_name), ",").c_str()); + printf(" -t, --threads (default: %s)\n", join(cmd_params_defaults.n_threads, ",").c_str()); + printf(" -C, --cpu-mask (default: %s)\n", join(cmd_params_defaults.cpu_mask, ",").c_str()); + printf(" --cpu-strict <0|1> (default: %s)\n", join(cmd_params_defaults.cpu_strict, ",").c_str()); + printf(" --poll <0...100> (default: %s)\n", join(cmd_params_defaults.poll, ",").c_str()); + printf(" -ngl, --n-gpu-layers (default: %s)\n", join(cmd_params_defaults.n_gpu_layers, ",").c_str()); + printf(" -rpc, --rpc (default: %s)\n", join(cmd_params_defaults.rpc_servers, ",").c_str()); + printf(" -sm, --split-mode (default: %s)\n", join(transform_to_str(cmd_params_defaults.split_mode, split_mode_str), ",").c_str()); + printf(" -mg, --main-gpu (default: %s)\n", join(cmd_params_defaults.main_gpu, ",").c_str()); + printf(" -nkvo, --no-kv-offload <0|1> (default: %s)\n", join(cmd_params_defaults.no_kv_offload, ",").c_str()); + printf(" -fa, --flash-attn <0|1> (default: %s)\n", join(cmd_params_defaults.flash_attn, ",").c_str()); + printf(" -mmp, --mmap <0|1> (default: %s)\n", join(cmd_params_defaults.use_mmap, ",").c_str()); + printf(" --numa (default: disabled)\n"); + printf(" -embd, --embeddings <0|1> (default: %s)\n", join(cmd_params_defaults.embeddings, ",").c_str()); + printf(" -ts, --tensor-split (default: 0)\n"); + printf(" -r, --repetitions (default: %d)\n", cmd_params_defaults.reps); + printf(" --prio <0|1|2|3> (default: %d)\n", cmd_params_defaults.prio); + printf(" --delay <0...N> (seconds) (default: %d)\n", cmd_params_defaults.delay); + printf(" -o, --output (default: %s)\n", output_format_str(cmd_params_defaults.output_format)); + printf(" -oe, --output-err (default: %s)\n", output_format_str(cmd_params_defaults.output_format_stderr)); + printf(" -v, --verbose (default: %s)\n", cmd_params_defaults.verbose ? "1" : "0"); printf("\n"); printf("Multiple values can be given for each parameter by separating them with ',' or by specifying the parameter multiple times.\n"); } @@ -1074,38 +1077,39 @@ struct csv_printer : public printer { } }; + +static std::string escape_json(const std::string & value) { + std::string escaped; + for (auto c : value) { + if (c == '"') { + escaped += "\\\""; + } else if (c == '\\') { + escaped += "\\\\"; + } else if (c <= 0x1f) { + char buf[8]; + snprintf(buf, sizeof(buf), "\\u%04x", c); + escaped += buf; + } else { + escaped += c; + } + } + return escaped; +} + +static std::string format_json_value(const std::string & field, const std::string & value) { + switch (test::get_field_type(field)) { + case test::STRING: + return "\"" + escape_json(value) + "\""; + case test::BOOL: + return value == "0" ? "false" : "true"; + default: + return value; + } +} + struct json_printer : public printer { bool first = true; - static std::string escape_json(const std::string & value) { - std::string escaped; - for (auto c : value) { - if (c == '"') { - escaped += "\\\""; - } else if (c == '\\') { - escaped += "\\\\"; - } else if (c <= 0x1f) { - char buf[8]; - snprintf(buf, sizeof(buf), "\\u%04x", c); - escaped += buf; - } else { - escaped += c; - } - } - return escaped; - } - - static std::string format_value(const std::string & field, const std::string & value) { - switch (test::get_field_type(field)) { - case test::STRING: - return "\"" + escape_json(value) + "\""; - case test::BOOL: - return value == "0" ? "false" : "true"; - default: - return value; - } - } - void print_header(const cmd_params & params) override { fprintf(fout, "[\n"); (void) params; @@ -1114,7 +1118,7 @@ struct json_printer : public printer { void print_fields(const std::vector & fields, const std::vector & values) { assert(fields.size() == values.size()); for (size_t i = 0; i < fields.size(); i++) { - fprintf(fout, " \"%s\": %s,\n", fields.at(i).c_str(), format_value(fields.at(i), values.at(i)).c_str()); + fprintf(fout, " \"%s\": %s,\n", fields.at(i).c_str(), format_json_value(fields.at(i), values.at(i)).c_str()); } } @@ -1137,6 +1141,25 @@ struct json_printer : public printer { } }; + +struct jsonl_printer : public printer { + void print_fields(const std::vector & fields, const std::vector & values) { + assert(fields.size() == values.size()); + for (size_t i = 0; i < fields.size(); i++) { + fprintf(fout, "\"%s\": %s, ", fields.at(i).c_str(), format_json_value(fields.at(i), values.at(i)).c_str()); + } + } + + void print_test(const test & t) override { + fprintf(fout, "{"); + print_fields(test::get_fields(), t.get_values()); + fprintf(fout, "\"samples_ns\": [ %s ],", join(t.samples_ns, ", ").c_str()); + fprintf(fout, "\"samples_ts\": [ %s ]", join(t.get_ts(), ", ").c_str()); + fprintf(fout, "}\n"); + fflush(fout); + } +}; + struct markdown_printer : public printer { std::vector fields; @@ -1437,6 +1460,8 @@ static std::unique_ptr create_printer(output_formats format) { return std::unique_ptr(new csv_printer()); case JSON: return std::unique_ptr(new json_printer()); + case JSONL: + return std::unique_ptr(new jsonl_printer()); case MARKDOWN: return std::unique_ptr(new markdown_printer()); case SQL: From 7605ae7daf31c02211bcfec2f46635ef6ec4b98a Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 4 Sep 2024 02:36:43 +0300 Subject: [PATCH 16/22] flake.lock: Update (#9261) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Flake lock file updates: • Updated input 'flake-parts': 'github:hercules-ci/flake-parts/8471fe90ad337a8074e957b69ca4d0089218391d?narHash=sha256-XOQkdLafnb/p9ij77byFQjDf5m5QYl9b2REiVClC%2Bx4%3D' (2024-08-01) → 'github:hercules-ci/flake-parts/af510d4a62d071ea13925ce41c95e3dec816c01d?narHash=sha256-ODYRm8zHfLTH3soTFWE452ydPYz2iTvr9T8ftDMUQ3E%3D' (2024-08-30) • Updated input 'nixpkgs': 'github:NixOS/nixpkgs/c374d94f1536013ca8e92341b540eba4c22f9c62?narHash=sha256-Z/ELQhrSd7bMzTO8r7NZgi9g5emh%2BaRKoCdaAv5fiO0%3D' (2024-08-21) → 'github:NixOS/nixpkgs/71e91c409d1e654808b2621f28a327acfdad8dc2?narHash=sha256-GnR7/ibgIH1vhoy8cYdmXE6iyZqKqFxQSVkFgosBh6w%3D' (2024-08-28) Co-authored-by: github-actions[bot] --- flake.lock | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/flake.lock b/flake.lock index cc1ebe299..10e1f8a29 100644 --- a/flake.lock +++ b/flake.lock @@ -5,11 +5,11 @@ "nixpkgs-lib": "nixpkgs-lib" }, "locked": { - "lastModified": 1722555600, - "narHash": "sha256-XOQkdLafnb/p9ij77byFQjDf5m5QYl9b2REiVClC+x4=", + "lastModified": 1725024810, + "narHash": "sha256-ODYRm8zHfLTH3soTFWE452ydPYz2iTvr9T8ftDMUQ3E=", "owner": "hercules-ci", "repo": "flake-parts", - "rev": "8471fe90ad337a8074e957b69ca4d0089218391d", + "rev": "af510d4a62d071ea13925ce41c95e3dec816c01d", "type": "github" }, "original": { @@ -20,11 +20,11 @@ }, "nixpkgs": { "locked": { - "lastModified": 1724224976, - "narHash": "sha256-Z/ELQhrSd7bMzTO8r7NZgi9g5emh+aRKoCdaAv5fiO0=", + "lastModified": 1724819573, + "narHash": "sha256-GnR7/ibgIH1vhoy8cYdmXE6iyZqKqFxQSVkFgosBh6w=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "c374d94f1536013ca8e92341b540eba4c22f9c62", + "rev": "71e91c409d1e654808b2621f28a327acfdad8dc2", "type": "github" }, "original": { From 9379d3cc1718e9ec6ab5ffcb60a06bbe048a61ee Mon Sep 17 00:00:00 2001 From: Pascal Patry Date: Wed, 4 Sep 2024 02:45:40 -0400 Subject: [PATCH 17/22] readme : rename result_format to response_format (#9300) --- grammars/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grammars/README.md b/grammars/README.md index 01b02abb4..7ec815471 100644 --- a/grammars/README.md +++ b/grammars/README.md @@ -120,7 +120,7 @@ You can use GBNF grammars: - In [llama-server](../examples/server): - For any completion endpoints, passed as the `json_schema` body field - - For the `/chat/completions` endpoint, passed inside the `result_format` body field (e.g. `{"type", "json_object", "schema": {"items": {}}}`) + - For the `/chat/completions` endpoint, passed inside the `response_format` body field (e.g. `{"type", "json_object", "schema": {"items": {}}}`) - In [llama-cli](../examples/main), passed as the `--json` / `-j` flag - To convert to a grammar ahead of time: - in CLI, with [examples/json_schema_to_grammar.py](../examples/json_schema_to_grammar.py) From 82e3b03c11826d20a24ab66d60f4de58f48ddcdb Mon Sep 17 00:00:00 2001 From: Radoslav Gerganov Date: Wed, 4 Sep 2024 11:08:32 +0300 Subject: [PATCH 18/22] rpc : make RPC servers come first in the device list (#9296) * rpc : make RPC servers come first in the device list * rpc : disable options for non-RPC builds * rpc : rpc_count always zero for non-RPC builds --- common/common.cpp | 4 ++ examples/llama-bench/llama-bench.cpp | 4 ++ src/llama.cpp | 75 ++++++++++++++++------------ 3 files changed, 50 insertions(+), 33 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 9fa184725..d22b04968 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1234,11 +1234,13 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa #endif // GGML_USE_CUDA_SYCL_VULKAN return true; } +#ifdef GGML_USE_RPC if (arg == "--rpc") { CHECK_ARG params.rpc_servers = argv[i]; return true; } +#endif if (arg == "--no-mmap") { params.use_mmap = false; return true; @@ -1929,7 +1931,9 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "*", " --image FILE", "path to an image file. use with multimodal models. Specify multiple times for batching" }); options.push_back({ "backend" }); +#ifdef GGML_USE_RPC options.push_back({ "*", " --rpc SERVERS", "comma separated list of RPC servers" }); +#endif if (llama_supports_mlock()) { options.push_back({ "*", " --mlock", "force system to keep model in RAM rather than swapping or compressing" }); diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index b65662ecd..bf97cc684 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -299,7 +299,9 @@ static void print_usage(int /* argc */, char ** argv) { printf(" --cpu-strict <0|1> (default: %s)\n", join(cmd_params_defaults.cpu_strict, ",").c_str()); printf(" --poll <0...100> (default: %s)\n", join(cmd_params_defaults.poll, ",").c_str()); printf(" -ngl, --n-gpu-layers (default: %s)\n", join(cmd_params_defaults.n_gpu_layers, ",").c_str()); +#ifdef GGML_USE_RPC printf(" -rpc, --rpc (default: %s)\n", join(cmd_params_defaults.rpc_servers, ",").c_str()); +#endif printf(" -sm, --split-mode (default: %s)\n", join(transform_to_str(cmd_params_defaults.split_mode, split_mode_str), ",").c_str()); printf(" -mg, --main-gpu (default: %s)\n", join(cmd_params_defaults.main_gpu, ",").c_str()); printf(" -nkvo, --no-kv-offload <0|1> (default: %s)\n", join(cmd_params_defaults.no_kv_offload, ",").c_str()); @@ -482,12 +484,14 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { } auto p = string_split(argv[i], split_delim); params.n_gpu_layers.insert(params.n_gpu_layers.end(), p.begin(), p.end()); +#ifdef GGML_USE_RPC } else if (arg == "-rpc" || arg == "--rpc") { if (++i >= argc) { invalid_param = true; break; } params.rpc_servers.push_back(argv[i]); +#endif } else if (arg == "-sm" || arg == "--split-mode") { if (++i >= argc) { invalid_param = true; diff --git a/src/llama.cpp b/src/llama.cpp index 883559716..c3669eb28 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -3346,29 +3346,33 @@ static size_t llama_get_device_count(const llama_model & model) { static ggml_backend_buffer_type_t llama_default_buffer_type_offload(const llama_model & model, int gpu) { ggml_backend_buffer_type_t buft = nullptr; -#if defined(GGML_USE_RPC) - int dev_count = (int)llama_get_device_count(model); +#ifdef GGML_USE_RPC int rpc_count = (int)model.rpc_servers.size(); - if (gpu >= dev_count - rpc_count) { - const char * endpoint = model.rpc_servers[gpu - dev_count + rpc_count].c_str(); +#else + int rpc_count = 0; +#endif + int local_gpu = gpu - rpc_count; +#if defined(GGML_USE_RPC) + if (gpu < rpc_count) { + const char * endpoint = model.rpc_servers[gpu].c_str(); return ggml_backend_rpc_buffer_type(endpoint); } #endif #if defined(GGML_USE_METAL) buft = ggml_backend_metal_buffer_type(); #elif defined(GGML_USE_CUDA) - buft = ggml_backend_cuda_buffer_type(gpu); + buft = ggml_backend_cuda_buffer_type(local_gpu); #elif defined(GGML_USE_VULKAN) - buft = ggml_backend_vk_buffer_type(gpu); + buft = ggml_backend_vk_buffer_type(local_gpu); #elif defined(GGML_USE_SYCL) - buft = ggml_backend_sycl_buffer_type(gpu); + buft = ggml_backend_sycl_buffer_type(local_gpu); #elif defined(GGML_USE_KOMPUTE) - buft = ggml_backend_kompute_buffer_type(gpu); + buft = ggml_backend_kompute_buffer_type(local_gpu); if (buft == nullptr) { - LLAMA_LOG_WARN("%s: cannot use GPU %d, check `vulkaninfo --summary`\n", __func__, gpu); + LLAMA_LOG_WARN("%s: cannot use GPU %d, check `vulkaninfo --summary`\n", __func__, local_gpu); } #elif defined(GGML_USE_CANN) - buft = ggml_backend_cann_buffer_type(gpu); + buft = ggml_backend_cann_buffer_type(local_gpu); #endif if (buft == nullptr) { @@ -3376,7 +3380,7 @@ static ggml_backend_buffer_type_t llama_default_buffer_type_offload(const llama_ } return buft; GGML_UNUSED(model); - GGML_UNUSED(gpu); + GGML_UNUSED(local_gpu); } static ggml_backend_buffer_type_t llama_default_buffer_type_split(const llama_model & model, int fallback_gpu, const float * tensor_split) { @@ -3403,13 +3407,17 @@ static ggml_backend_buffer_type_t llama_default_buffer_type_split(const llama_mo } static size_t llama_get_device_memory(const llama_model & model, int device) { -#if defined(GGML_USE_RPC) - int dev_count = (int)llama_get_device_count(model); +#ifdef GGML_USE_RPC int rpc_count = (int)model.rpc_servers.size(); - if (device >= dev_count - rpc_count) { +#else + int rpc_count = 0; +#endif + int local_device = device - rpc_count; +#if defined(GGML_USE_RPC) + if (device < rpc_count) { size_t total; size_t free; - const char * endpoint = model.rpc_servers[device - dev_count + rpc_count].c_str(); + const char * endpoint = model.rpc_servers[device].c_str(); ggml_backend_rpc_get_device_memory(endpoint, &free, &total); return free; } @@ -3417,28 +3425,28 @@ static size_t llama_get_device_memory(const llama_model & model, int device) { #if defined(GGML_USE_CUDA) size_t total; size_t free; - ggml_backend_cuda_get_device_memory(device, &free, &total); + ggml_backend_cuda_get_device_memory(local_device, &free, &total); return free; #elif defined(GGML_USE_SYCL) size_t total; size_t free; - ggml_backend_sycl_get_device_memory(device, &free, &total); + ggml_backend_sycl_get_device_memory(local_device, &free, &total); return free; #elif defined(GGML_USE_VULKAN) size_t total; size_t free; - ggml_backend_vk_get_device_memory(device, &free, &total); + ggml_backend_vk_get_device_memory(local_device, &free, &total); return free; #elif defined(GGML_USE_CANN) size_t total; size_t free; - ggml_backend_cann_get_device_memory(device, &free, &total); + ggml_backend_cann_get_device_memory(local_device, &free, &total); return free; #else return 1; #endif GGML_UNUSED(model); - GGML_UNUSED(device); + GGML_UNUSED(local_device); } // @@ -18186,6 +18194,20 @@ struct llama_context * llama_new_context_with_model( if (!hparams.vocab_only) { // initialize backends +#if defined(GGML_USE_RPC) + if (model->n_gpu_layers > 0) { + for (const auto & endpoint : model->rpc_servers) { + ggml_backend_t backend = ggml_backend_rpc_init(endpoint.c_str()); + if (backend == nullptr) { + LLAMA_LOG_ERROR("%s: failed to initialize RPC to '%s'\n", __func__, endpoint.c_str()); + llama_free(ctx); + return nullptr; + } + ctx->backends.push_back(backend); + } + } +#endif + #if defined(GGML_USE_METAL) if (model->n_gpu_layers > 0) { ctx->backend_metal = ggml_backend_metal_init(); @@ -18310,19 +18332,6 @@ struct llama_context * llama_new_context_with_model( } #endif -#if defined(GGML_USE_RPC) - if (model->n_gpu_layers > 0) { - for (const auto & endpoint : model->rpc_servers) { - ggml_backend_t backend = ggml_backend_rpc_init(endpoint.c_str()); - if (backend == nullptr) { - LLAMA_LOG_ERROR("%s: failed to initialize RPC to '%s'\n", __func__, endpoint.c_str()); - llama_free(ctx); - return nullptr; - } - ctx->backends.push_back(backend); - } - } -#endif ctx->backend_cpu = ggml_backend_cpu_init(); if (ctx->backend_cpu == nullptr) { LLAMA_LOG_ERROR("%s: failed to initialize CPU backend\n", __func__); From c8671ae28214b29920b86c66a5d36fd6dd0db870 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=A8=E6=9C=B1=20=C2=B7=20Kiki?= Date: Wed, 4 Sep 2024 19:45:28 +0800 Subject: [PATCH 19/22] Fix broken links in docker.md (#9306) --- docs/docker.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/docker.md b/docs/docker.md index e25838255..e8a084173 100644 --- a/docs/docker.md +++ b/docs/docker.md @@ -20,7 +20,7 @@ Additionally, there the following images, similar to the above: - `ghcr.io/ggerganov/llama.cpp:light-rocm`: Same as `light` but compiled with ROCm support. (platforms: `linux/amd64`, `linux/arm64`) - `ghcr.io/ggerganov/llama.cpp:server-rocm`: Same as `server` but compiled with ROCm support. (platforms: `linux/amd64`, `linux/arm64`) -The GPU enabled images are not currently tested by CI beyond being built. They are not built with any variation from the ones in the Dockerfiles defined in [.devops/](.devops/) and the GitHub Action defined in [.github/workflows/docker.yml](.github/workflows/docker.yml). If you need different settings (for example, a different CUDA or ROCm library, you'll need to build the images locally for now). +The GPU enabled images are not currently tested by CI beyond being built. They are not built with any variation from the ones in the Dockerfiles defined in [.devops/](../.devops/) and the GitHub Action defined in [.github/workflows/docker.yml](../.github/workflows/docker.yml). If you need different settings (for example, a different CUDA or ROCm library, you'll need to build the images locally for now). ## Usage From 5910ea942772ab6cbc21d0ad2d1208750ba39e1d Mon Sep 17 00:00:00 2001 From: Ouadie EL FAROUKI Date: Wed, 4 Sep 2024 16:26:33 +0100 Subject: [PATCH 20/22] [SYCL] Fix DMMV dequantization (#9279) Fixed dmmv dequant for ncols== GGML_SYCL_DMMV_X --- ggml/src/ggml-sycl/dmmv.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-sycl/dmmv.cpp b/ggml/src/ggml-sycl/dmmv.cpp index 5c343822f..0c3dfaa37 100644 --- a/ggml/src/ggml-sycl/dmmv.cpp +++ b/ggml/src/ggml-sycl/dmmv.cpp @@ -76,8 +76,8 @@ static void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * } // sum up partial sums and write back result -#pragma unroll - for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { + const int mask_start = ncols > GGML_SYCL_DMMV_X ? WARP_SIZE >> 1 : WARP_SIZE >> 2; + for (int mask = mask_start; mask > 0; mask >>= 1) { tmp += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); } From 581c305186a0ff93f360346c57e21fe16e967bb7 Mon Sep 17 00:00:00 2001 From: Srihari-mcw <96763064+Srihari-mcw@users.noreply.github.com> Date: Wed, 4 Sep 2024 22:21:22 +0530 Subject: [PATCH 21/22] ggml : AVX2 support for Q4_0_8_8 (#8713) * Add AVX2 based implementations for quantize_q8_0_4x8, ggml_gemv_q4_0_8x8_q8_0 and ggml_gemm_q4_0_8x8_q8_0 functions * Update code to fix issues occuring due to non alignment of elements to be processed as multiple of 16 in MSVC * Update comments and indentation * Make updates to reduce number of load instructions --- ggml/src/ggml-aarch64.c | 612 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 612 insertions(+) diff --git a/ggml/src/ggml-aarch64.c b/ggml/src/ggml-aarch64.c index 332578fd4..72cb83c9b 100644 --- a/ggml/src/ggml-aarch64.c +++ b/ggml/src/ggml-aarch64.c @@ -36,6 +36,84 @@ // from bias offset form to pure sign form (this saves subtract // operations durin unpacking) // +#if defined(__AVX__) +#if defined(__F16C__) +// the _mm256_cvt intrinsics require F16C +#define GGML_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x))) +#define GGML_F32Cx8_REPEAT_LOAD(x, loadMask) _mm256_cvtph_ps(_mm_shuffle_epi32(_mm_maskload_epi32((int const*)(x), loadMask), 68)) +#define GGML_F32Cx8_REARRANGE_LOAD(x, arrangeMask) _mm256_cvtph_ps(_mm_shuffle_epi8(_mm_loadu_si128((const __m128i *) x), arrangeMask)) +#else +static inline __m256 __avx_f32cx8_load(ggml_fp16_t *x) { + float tmp[8]; + + for (int i = 0; i < 8; i++) { + tmp[i] = GGML_FP16_TO_FP32(x[i]); + } + + return _mm256_loadu_ps(tmp); +} +static inline __m256 __avx_repeat_f32cx8_load(ggml_fp16_t *x) { + float tmp[8]; + + for (int i = 0; i < 4; i++) { + tmp[i] = GGML_FP16_TO_FP32(x[i]); + tmp[i + 4] = GGML_FP16_TO_FP32(x[i]); + } + + return _mm256_loadu_ps(tmp); +} +static inline __m256 __avx_rearranged_f32cx8_load(ggml_fp16_t *x, __m128i arrangeMask) { + uint16_t tmphalf[8]; + float tmp[8]; + + _mm_storeu_si128((__m128i*)tmphalf, _mm_shuffle_epi8(_mm_loadu_si128((const __m128i *) x), arrangeMask)); + for (int i = 0; i < 8; i++) { + tmp[i] = GGML_FP16_TO_FP32(tmphalf[i]); + } + + return _mm256_loadu_ps(tmp); +} + +#define GGML_F32Cx8_LOAD(x) __avx_f32cx8_load(x) +#define GGML_F32Cx8_REPEAT_LOAD(x, loadMask) __avx_repeat_f32cx8_load(x) +#define GGML_F32Cx8_REARRANGE_LOAD(x, arrangeMask) __avx_rearranged_f32cx8_load(x, arrangeMask) +#endif +#endif + + +#if defined(__AVX2__) || defined(__AVX512F__) +static inline __m256i sum_i16_pairs_int(const __m256i x) { + const __m256i ones = _mm256_set1_epi16(1); + return _mm256_madd_epi16(ones, x); +} + +static inline __m256i mul_sum_us8_pairs_int(const __m256i ax, const __m256i sy) { +#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__)) + const __m256i zero = _mm256_setzero_si256(); + return _mm256_dpbusd_epi32(zero, ax, sy); +#else + // Perform multiplication and create 16-bit values + const __m256i dot = _mm256_maddubs_epi16(ax, sy); + return sum_i16_pairs_int(dot); +#endif +} + +// Integer variant of the function defined in ggml-quants.c +// multiply int8_t, add results pairwise twice and return as float vector +static inline __m256i mul_sum_i8_pairs_int(const __m256i x, const __m256i y) { +#if __AVXVNNIINT8__ + const __m256i zero = _mm256_setzero_si256(); + return _mm256_dpbssd_epi32(zero, x, y); +#else + // Get absolute values of x vectors + const __m256i ax = _mm256_sign_epi8(x, x); + // Sign the values of the y vectors + const __m256i sy = _mm256_sign_epi8(y, x); + return mul_sum_us8_pairs_int(ax, sy); +#endif +} +#endif + static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int blck_size_interleave, unsigned int xor_mask) { block_q4_0x4 out; @@ -255,6 +333,103 @@ void quantize_q8_0_4x8(const float * restrict x, void * restrict vy, int64_t k) y[i].qs[32 * j + 31] = vgetq_lane_s32(vi, 3); } } +#elif defined(__AVX2__) || defined(__AVX__) + float id[4]; + __m256 srcv[4][4]; + __m256 idvec[4]; + + for (int i = 0; i < nb; i++) { + for (int row_iter = 0; row_iter < 4; row_iter++) { + // Load elements into 4 AVX vectors + __m256 v0 = _mm256_loadu_ps( x + row_iter * k + i * 32 ); + __m256 v1 = _mm256_loadu_ps( x + row_iter * k + i * 32 + 8 ); + __m256 v2 = _mm256_loadu_ps( x + row_iter * k + i * 32 + 16 ); + __m256 v3 = _mm256_loadu_ps( x + row_iter * k + i * 32 + 24 ); + + // Compute max(abs(e)) for the block + const __m256 signBit = _mm256_set1_ps( -0.0f ); + __m256 maxAbs = _mm256_andnot_ps( signBit, v0 ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) ); + + __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) ); + max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) ); + max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) ); + const float maxScalar = _mm_cvtss_f32( max4 ); + + // Divided by 127.f to mirror results in quantize_row_q8_0 + const float d = maxScalar / 127.f; + id[row_iter] = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f; //d ? 1.0f / d : 0.0f; + + // Store the scale for the individual block + y[i].d[row_iter] = GGML_FP32_TO_FP16(d); + + // Store the values in blocks of eight values - Aim is to use these later for block interleaving + srcv[row_iter][0] = v0; + srcv[row_iter][1] = v1; + srcv[row_iter][2] = v2; + srcv[row_iter][3] = v3; + idvec[row_iter] = _mm256_set1_ps(id[row_iter]); + } + + // The loop iterates four times - The aim is to get 4 corresponding chunks of eight bytes from the original weight blocks that are interleaved + for (int j = 0; j < 4; j++) { + // Apply the multiplier + __m256 v0 = _mm256_mul_ps(srcv[0][j], idvec[0]); + __m256 v1 = _mm256_mul_ps(srcv[1][j], idvec[1]); + __m256 v2 = _mm256_mul_ps(srcv[2][j], idvec[2]); + __m256 v3 = _mm256_mul_ps(srcv[3][j], idvec[3]); + + // Round to nearest integer + v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST ); + v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST ); + v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST ); + v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST ); + + // Convert floats to integers + __m256i i0 = _mm256_cvtps_epi32( v0 ); + __m256i i1 = _mm256_cvtps_epi32( v1 ); + __m256i i2 = _mm256_cvtps_epi32( v2 ); + __m256i i3 = _mm256_cvtps_epi32( v3 ); + +#if defined(__AVX2__) + // Convert int32 to int16 + i0 = _mm256_packs_epi32( i0, i1 ); + i2 = _mm256_packs_epi32( i2, i3 ); + // Convert int16 to int8 + i0 = _mm256_packs_epi16( i0, i2 ); + + // Permute and store the quantized weights in the required order after the pack instruction + const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 ); + i0 = _mm256_permutevar8x32_epi32( i0, perm ); + + _mm256_storeu_si256((__m256i *)(y[i].qs + 32 * j), i0); +#else + // Since we don't have in AVX some necessary functions, + // we split the registers in half and call AVX2 analogs from SSE + __m128i ni0 = _mm256_castsi256_si128( i0 ); + __m128i ni1 = _mm256_extractf128_si256( i0, 1); + __m128i ni2 = _mm256_castsi256_si128( i1 ); + __m128i ni3 = _mm256_extractf128_si256( i1, 1); + __m128i ni4 = _mm256_castsi256_si128( i2 ); + __m128i ni5 = _mm256_extractf128_si256( i2, 1); + __m128i ni6 = _mm256_castsi256_si128( i3 ); + __m128i ni7 = _mm256_extractf128_si256( i3, 1); + + // Convert int32 to int16 + ni0 = _mm_packs_epi32( ni0, ni1 ); + ni2 = _mm_packs_epi32( ni2, ni3 ); + ni4 = _mm_packs_epi32( ni4, ni5 ); + ni6 = _mm_packs_epi32( ni6, ni7 ); + // Convert int16 to int8 + ni0 = _mm_packs_epi16( ni0, ni2 ); + ni4 = _mm_packs_epi16( ni4, ni6 ); + _mm_storeu_si128((__m128i *)(y[i].qs + 32 * j), ni0); + _mm_storeu_si128((__m128i *)(y[i].qs + 32 * j + 16), ni4); +#endif + } + } #else // scalar const int blck_size_interleave = 8; @@ -684,6 +859,96 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void * GGML_ASSERT((ggml_cpu_has_sve() || ggml_cpu_has_matmul_int8()) && "__ARM_FEATURE_SVE and __ARM_FEATURE_MATMUL_INT8 not defined, use the Q4_0_4_4 quantization format for optimal " "performance"); +#elif defined(__AVX2__) + // Lookup table to convert signed nibbles to signed bytes + __m256i signextendlut = _mm256_castsi128_si256(_mm_set_epi8(-1, -2, -3, -4, -5, -6, -7, -8, 7, 6, 5, 4, 3, 2, 1, 0)); + signextendlut = _mm256_permute2f128_si256(signextendlut, signextendlut, 0); + __m128i changemask = _mm_set_epi8(15, 14, 7, 6, 13, 12, 5, 4, 11, 10, 3, 2, 9, 8, 1, 0); + __m256i finalpermutemask = _mm256_set_epi32(7, 5, 3, 1, 6, 4, 2, 0); + + // Permute mask used for easier vector processing at later stages + const __m256i m4b = _mm256_set1_epi8(0x0F); + + int64_t b_nb = n / QK4_0; + + const block_q4_0x8 * b_ptr_start = (const block_q4_0x8 *)vx; + const block_q8_0 * a_ptr_start = (const block_q8_0 *)vy; + + // Process Q8_0 blocks one by one + for (int64_t y = 0; y < nr; y++) { + + // Pointers to LHS blocks of block_q8_0 format + const block_q8_0 * a_ptr = a_ptr_start + (y * nb); + + // Take group of eight block_q4_0x8 structures at each pass of the loop and perform dot product operation + for (int64_t x = 0; x < nc / 8; x++) { + + // Pointers to RHS blocks + const block_q4_0x8 * b_ptr = b_ptr_start + (x * b_nb); + + // Master FP accumulator + __m256 acc_row = _mm256_setzero_ps(); + + for (int64_t b = 0; b < nb; b++) { + // Load 8 blocks of Q4_0 interleaved as 8 bytes (B0 - B7) + const __m256i rhs_raw_vec_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs)); + const __m256i rhs_raw_vec_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs) + 1); + const __m256i rhs_raw_vec_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs) + 2); + const __m256i rhs_raw_vec_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs) + 3); + + // 4-bit -> 8-bit - Sign is maintained + const __m256i rhs_vec_0123_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_vec_0123_0, m4b)); // B0(0-7) B1(0-7) B2(0-7) B3(0-7) + const __m256i rhs_vec_4567_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_vec_4567_0, m4b)); // B4(0-7) B5(0-7) B6(0-7) B7(0-7) + const __m256i rhs_vec_0123_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_vec_0123_1, m4b)); // B0(8-15) B1(8-15) B2(8-15) B3(8-15) + const __m256i rhs_vec_4567_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_vec_4567_1, m4b)); // B0(8-15) B1(8-15) B2(8-15) B3(8-15) + + const __m256i rhs_vec_0123_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_0, 4), m4b)); // B0(16-23) B1(16-23) B2(16-23) B3(16-23) + const __m256i rhs_vec_4567_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_0, 4), m4b)); // B4(16-23) B5(16-23) B6(16-23) B7(16-23) + const __m256i rhs_vec_0123_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_1, 4), m4b)); // B0(24-31) B1(24-31) B2(24-31) B3(24-31) + const __m256i rhs_vec_4567_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_1, 4), m4b)); // B4(24-31) B5(24-31) B6(24-31) B7(24-31) + + // Load the scale values for the 8 blocks interleaved in block_q4_0x8 + const __m256 col_scale_f32 = GGML_F32Cx8_REARRANGE_LOAD(b_ptr[b].d, changemask); + + // Load and convert to FP32 scale from block_q8_0 + const __m256 row_scale_f32 = _mm256_set1_ps(GGML_FP16_TO_FP32(a_ptr[b].d)); + + // Load the block values in block_q8_0 in batches of 16 bytes and replicate the same across 256 bit vector + __m256i lhs_vec_0 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)a_ptr[b].qs)); + __m256i lhs_vec_1 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 16))); + + lhs_vec_0 = _mm256_permute2f128_si256(lhs_vec_0, lhs_vec_0, 0); // A0 (0-15) A0(0-15) + lhs_vec_1 = _mm256_permute2f128_si256(lhs_vec_1, lhs_vec_1, 0); // A0 (16-31) A0(16-31)) + + __m256i iacc = _mm256_setzero_si256(); + + // Dot product done within 32 bit lanes and accumulated in the same vector + // B0(0-3) B4(0-3) B1(0-3) B5(0-3) B2(0-3) B6(0-3) B3(0-3) B7(0-3) with A0(0-3) + // B0(4-7) B4(4-7) B1(4-7) B5(4-7) B2(4-7) B6(4-7) B3(4-7) B7(4-7) with A0(4-7) + // ........................................................................... + // B0(28-31) B4(28-31) B1(28-31) B5(28-31) B2(28-31) B6(28-31) B3(28-31) B7(28-31) with A0(28-31) + + iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int(_mm256_blend_epi32(rhs_vec_0123_0 ,_mm256_shuffle_epi32(rhs_vec_4567_0, 177), 170), _mm256_shuffle_epi32(lhs_vec_0, 0))); + iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_0, 177) ,rhs_vec_4567_0, 170), _mm256_shuffle_epi32(lhs_vec_0, 85))); + + iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int(_mm256_blend_epi32(rhs_vec_0123_1 ,_mm256_shuffle_epi32(rhs_vec_4567_1, 177), 170), _mm256_shuffle_epi32(lhs_vec_0, 170))); + iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_1, 177) ,rhs_vec_4567_1, 170), _mm256_shuffle_epi32(lhs_vec_0, 255))); + + iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int(_mm256_blend_epi32(rhs_vec_0123_2 ,_mm256_shuffle_epi32(rhs_vec_4567_2, 177), 170), _mm256_shuffle_epi32(lhs_vec_1, 0))); + iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_2, 177) ,rhs_vec_4567_2, 170), _mm256_shuffle_epi32(lhs_vec_1, 85))); + + iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int(_mm256_blend_epi32(rhs_vec_0123_3 ,_mm256_shuffle_epi32(rhs_vec_4567_3, 177), 170), _mm256_shuffle_epi32(lhs_vec_1, 170))); + iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_3, 177) ,rhs_vec_4567_3, 170), _mm256_shuffle_epi32(lhs_vec_1, 255))); + + // Accumulated values multipled with appropriate scales + acc_row = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc), _mm256_mul_ps(col_scale_f32, row_scale_f32), acc_row); + } + + // Accumulated output values permuted so as to be stored in appropriate order post accumulation + acc_row = _mm256_permutevar8x32_ps(acc_row, finalpermutemask); + _mm256_storeu_ps(s + (y * nr + x * 8), acc_row); + } + } #else float sumf[8]; int sumi; @@ -2143,6 +2408,353 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void * GGML_ASSERT((ggml_cpu_has_sve() || ggml_cpu_has_matmul_int8()) && "__ARM_FEATURE_SVE and __ARM_FEATURE_MATMUL_INT8 not defined, use the Q4_0_4_4 quantization format for optimal " "performance"); +#elif defined(__AVX2__) || defined(__AVX512F__) + const block_q4_0x8 * b_ptr_start = (const block_q4_0x8 *)vx; + const block_q8_0x4 * a_ptr_start = (const block_q8_0x4 *)vy; + int64_t b_nb = n / QK4_0; + int64_t y = 0; + // Mask to mask out nibbles from packed bytes + const __m256i m4b = _mm256_set1_epi8(0x0F); + const __m128i loadMask = _mm_blend_epi32(_mm_setzero_si128(), _mm_set1_epi32(0xFFFFFFFF), 3); + // Lookup table to convert signed nibbles to signed bytes + __m256i signextendlut = _mm256_castsi128_si256(_mm_set_epi8(-1, -2, -3, -4, -5, -6, -7, -8, 7, 6, 5, 4, 3, 2, 1, 0)); + signextendlut = _mm256_permute2f128_si256(signextendlut, signextendlut, 0); + // Permute mask used for easier vector processing at later stages + __m256i requiredOrder = _mm256_set_epi32(3 ,2 ,1 ,0, 7 ,6, 5, 4); + + // Take group of four block_q8_0x4 structures at each pass of the loop and perform dot product operation + int anr = nr - nr %16; // Used to align nr with boundary of 16 + + for (; y < anr / 4; y += 4) { + const block_q8_0x4 * a_ptrs[4]; + + a_ptrs[0] = a_ptr_start + (y * nb); + for (int i = 0; i < 3; ++i) { + a_ptrs[i + 1] = a_ptrs[i] + nb; + } + + // Take group of eight block_q4_0x8 structures at each pass of the loop and perform dot product operation + for (int64_t x = 0; x < nc / 8; x++) { + + const block_q4_0x8 * b_ptr = b_ptr_start + (x * b_nb); + + // Master FP accumulators + __m256 acc_rows[16]; + for (int i = 0; i < 16; i++) { + acc_rows[i] = _mm256_setzero_ps(); + } + + for (int64_t b = 0; b < nb; b++) { + // Load the eight block_q4_0 quantized values interleaved with each other in chunks of eight - B0,B1 ....B6,B7 + const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs)); + const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 32)); + const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 64)); + const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 96)); + + // Save the values in the following vectors in the formats B0B1B4B5, B2B3B6B7 for further processing and storing of values + const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240); + const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240); + const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240); + const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240); + + // 4-bit -> 8-bit - Sign is maintained + const __m256i rhs_mat_0145_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_0145_0, m4b)); //B0(0-7) B1(0-7) B4(0-7) B5(0-7) + const __m256i rhs_mat_2367_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_2367_0, m4b)); //B2(0-7) B3(0-7) B6(0-7) B7(0-7) + + const __m256i rhs_mat_0145_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_0145_1, m4b)); //B0(8-15) B1(8-15) B4(8-15) B5(8-15) + const __m256i rhs_mat_2367_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_2367_1, m4b)); //B2(8-15) B3(8-15) B6(8-15) B7(8-15) + + const __m256i rhs_mat_0145_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_0, 4), m4b)); //B0(16-23) B1(16-23) B4(16-23) B5(16-23) + const __m256i rhs_mat_2367_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_0, 4), m4b)); //B2(16-23) B3(16-23) B6(16-23) B7(16-23) + + const __m256i rhs_mat_0145_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_1, 4), m4b)); //B0(24-31) B1(24-31) B4(24-31) B5(24-31) + const __m256i rhs_mat_2367_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_1, 4), m4b)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31) + + // Shuffle pattern one - right side input + const __m256i rhs_mat_0145_0_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_0, 136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) + const __m256i rhs_mat_2367_0_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_0, 136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) + + const __m256i rhs_mat_0145_1_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_1, 136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) + const __m256i rhs_mat_2367_1_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_1, 136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) + + const __m256i rhs_mat_0145_2_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_2, 136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) + const __m256i rhs_mat_2367_2_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_2, 136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) + + const __m256i rhs_mat_0145_3_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_3, 136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) + const __m256i rhs_mat_2367_3_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_3, 136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) + + // Shuffle pattern two - right side input + + const __m256i rhs_mat_0145_0_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_0, 221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) + const __m256i rhs_mat_2367_0_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_0, 221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) + + const __m256i rhs_mat_0145_1_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_1, 221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) + const __m256i rhs_mat_2367_1_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_1, 221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) + + const __m256i rhs_mat_0145_2_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_2, 221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) + const __m256i rhs_mat_2367_2_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_2, 221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) + + const __m256i rhs_mat_0145_3_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_3, 221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) + const __m256i rhs_mat_2367_3_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_3, 221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) + + // Scale values - Load the wight scale values of block_q4_0x8 + const __m256 col_scale_f32 = GGML_F32Cx8_LOAD(b_ptr[b].d); + + // Process LHS in groups of four + for (int rp = 0; rp < 4; rp++) { + // Load the four block_q4_0 quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3 + // Loaded as set of 128 bit vectors and repeated into a 256 bit vector + __m256i lhs_mat_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs))); + __m256i lhs_mat_01_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 0); + __m256i lhs_mat_23_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 17); + __m256i lhs_mat_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 32))); + __m256i lhs_mat_01_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 0); + __m256i lhs_mat_23_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 17); + __m256i lhs_mat_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 64))); + __m256i lhs_mat_01_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 0); + __m256i lhs_mat_23_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 17); + __m256i lhs_mat_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 96))); + __m256i lhs_mat_01_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 0); + __m256i lhs_mat_23_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 17); + + // Shuffle pattern one - left side input + const __m256i lhs_mat_01_0_sp1 = _mm256_shuffle_epi32(lhs_mat_01_0, 160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) + const __m256i lhs_mat_23_0_sp1 = _mm256_shuffle_epi32(lhs_mat_23_0, 160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) + + const __m256i lhs_mat_01_1_sp1 = _mm256_shuffle_epi32(lhs_mat_01_1, 160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) + const __m256i lhs_mat_23_1_sp1 = _mm256_shuffle_epi32(lhs_mat_23_1, 160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) + + const __m256i lhs_mat_01_2_sp1 = _mm256_shuffle_epi32(lhs_mat_01_2, 160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) + const __m256i lhs_mat_23_2_sp1 = _mm256_shuffle_epi32(lhs_mat_23_2, 160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) + + const __m256i lhs_mat_01_3_sp1 = _mm256_shuffle_epi32(lhs_mat_01_3, 160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) + const __m256i lhs_mat_23_3_sp1 = _mm256_shuffle_epi32(lhs_mat_23_3, 160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) + + // Shuffle pattern two - left side input + const __m256i lhs_mat_01_0_sp2 = _mm256_shuffle_epi32(lhs_mat_01_0, 245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) + const __m256i lhs_mat_23_0_sp2 = _mm256_shuffle_epi32(lhs_mat_23_0, 245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) + + const __m256i lhs_mat_01_1_sp2 = _mm256_shuffle_epi32(lhs_mat_01_1, 245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) + const __m256i lhs_mat_23_1_sp2 = _mm256_shuffle_epi32(lhs_mat_23_1, 245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) + + const __m256i lhs_mat_01_2_sp2 = _mm256_shuffle_epi32(lhs_mat_01_2, 245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) + const __m256i lhs_mat_23_2_sp2 = _mm256_shuffle_epi32(lhs_mat_23_2, 245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) + + const __m256i lhs_mat_01_3_sp2 = _mm256_shuffle_epi32(lhs_mat_01_3, 245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) + const __m256i lhs_mat_23_3_sp2 = _mm256_shuffle_epi32(lhs_mat_23_3, 245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) + + // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane + // Resembles MMLAs into 2x2 matrices in ARM Version + __m256i iacc_mat_00_sp1 = + _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_01_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int(lhs_mat_01_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int(lhs_mat_01_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int(lhs_mat_01_0_sp1, rhs_mat_0145_0_sp1)); + __m256i iacc_mat_01_sp1 = + _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_01_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int(lhs_mat_01_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int(lhs_mat_01_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int(lhs_mat_01_0_sp1, rhs_mat_2367_0_sp1)); + __m256i iacc_mat_10_sp1 = + _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_23_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int(lhs_mat_23_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int(lhs_mat_23_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int(lhs_mat_23_0_sp1, rhs_mat_0145_0_sp1)); + __m256i iacc_mat_11_sp1 = + _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_23_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int(lhs_mat_23_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int(lhs_mat_23_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int(lhs_mat_23_0_sp1, rhs_mat_2367_0_sp1)); + __m256i iacc_mat_00_sp2 = + _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_01_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int(lhs_mat_01_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int(lhs_mat_01_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int(lhs_mat_01_0_sp2, rhs_mat_0145_0_sp2)); + __m256i iacc_mat_01_sp2 = + _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_01_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int(lhs_mat_01_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int(lhs_mat_01_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int(lhs_mat_01_0_sp2, rhs_mat_2367_0_sp2)); + __m256i iacc_mat_10_sp2 = + _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_23_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int(lhs_mat_23_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int(lhs_mat_23_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int(lhs_mat_23_0_sp2, rhs_mat_0145_0_sp2)); + __m256i iacc_mat_11_sp2 = + _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_23_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int(lhs_mat_23_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int(lhs_mat_23_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int(lhs_mat_23_0_sp2, rhs_mat_2367_0_sp2)); + + // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block + __m256i iacc_mat_00 = _mm256_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2); + __m256i iacc_mat_01 = _mm256_add_epi32(iacc_mat_01_sp1, iacc_mat_01_sp2); + __m256i iacc_mat_10 = _mm256_add_epi32(iacc_mat_10_sp1, iacc_mat_10_sp2); + __m256i iacc_mat_11 = _mm256_add_epi32(iacc_mat_11_sp1, iacc_mat_11_sp2); + + // Straighten out to make 4 row vectors + __m256i iacc_row_0 = _mm256_blend_epi32(iacc_mat_00, _mm256_shuffle_epi32(iacc_mat_01, 78), 204); + __m256i iacc_row_1 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_00, 78), iacc_mat_01, 204); + __m256i iacc_row_2 = _mm256_blend_epi32(iacc_mat_10, _mm256_shuffle_epi32(iacc_mat_11, 78), 204); + __m256i iacc_row_3 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_10, 78), iacc_mat_11, 204); + + // Load the scale(d) values for all the 4 Q8_0 blocks and repeat it across lanes + const __m256 row_scale_f32 = GGML_F32Cx8_REPEAT_LOAD(a_ptrs[rp][b].d, loadMask); + + // Multiply with appropiate scales and accumulate + acc_rows[rp * 4] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_0), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[rp * 4]); + acc_rows[rp * 4 + 1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_1), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[rp * 4 + 1]); + acc_rows[rp * 4 + 2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_2), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[rp * 4 + 2]); + acc_rows[rp * 4 + 3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_3), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[rp * 4 + 3]); + } + } + + // Store the accumulated values + for (int i = 0; i < 16; i++) { + _mm256_storeu_ps((float *)(s + ((y * 4 + i) * bs + x * 8)), acc_rows[i]); + } + } + } + + // Take a block_q8_0x4 structures at each pass of the loop and perform dot product operation + for (; y < nr / 4; y ++) { + + const block_q8_0x4 * a_ptr = a_ptr_start + (y * nb); + + // Load the eight block_q4_0 quantized values interleaved with each other in chunks of eight - B0,B1 ....B6,B7 + for (int64_t x = 0; x < nc / 8; x++) { + + const block_q4_0x8 * b_ptr = b_ptr_start + (x * b_nb); + + // Master FP accumulators + __m256 acc_rows[4]; + for (int i = 0; i < 4; i++) { + acc_rows[i] = _mm256_setzero_ps(); + } + + for (int64_t b = 0; b < nb; b++) { + // Load the eight block_q8_0 quantized values interleaved with each other in chunks of eight - B0,B1 ....B6,B7 + const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs)); + const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 32)); + const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 64)); + const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 96)); + + // Save the values in the following vectors in the formats B0B1B4B5, B2B3B6B7 for further processing and storing of valuess + const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240); + const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240); + const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240); + const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240); + + // 4-bit -> 8-bit - Sign is maintained + const __m256i rhs_mat_0145_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_0145_0, m4b)); //B0(0-7) B1(0-7) B4(0-7) B5(0-7) + const __m256i rhs_mat_2367_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_2367_0, m4b)); //B2(0-7) B3(0-7) B6(0-7) B7(0-7) + + const __m256i rhs_mat_0145_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_0145_1, m4b)); //B0(8-15) B1(8-15) B4(8-15) B5(8-15) + const __m256i rhs_mat_2367_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_2367_1, m4b)); //B2(8-15) B3(8-15) B6(8-15) B7(8-15) + + const __m256i rhs_mat_0145_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_0, 4), m4b)); //B0(16-23) B1(16-23) B4(16-23) B5(16-23) + const __m256i rhs_mat_2367_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_0, 4), m4b)); //B2(16-23) B3(16-23) B6(16-23) B7(16-23) + + const __m256i rhs_mat_0145_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_1, 4), m4b)); //B0(24-31) B1(24-31) B4(24-31) B5(24-31) + const __m256i rhs_mat_2367_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_1, 4), m4b)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31) + + // Shuffle pattern one - right side input + const __m256i rhs_mat_0145_0_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_0, 136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) + const __m256i rhs_mat_2367_0_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_0, 136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) + + const __m256i rhs_mat_0145_1_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_1, 136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) + const __m256i rhs_mat_2367_1_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_1, 136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) + + const __m256i rhs_mat_0145_2_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_2, 136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) + const __m256i rhs_mat_2367_2_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_2, 136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) + + const __m256i rhs_mat_0145_3_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_3, 136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) + const __m256i rhs_mat_2367_3_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_3, 136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) + + // Shuffle pattern two - right side input + + const __m256i rhs_mat_0145_0_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_0, 221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) + const __m256i rhs_mat_2367_0_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_0, 221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) + + const __m256i rhs_mat_0145_1_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_1, 221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) + const __m256i rhs_mat_2367_1_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_1, 221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) + + const __m256i rhs_mat_0145_2_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_2, 221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) + const __m256i rhs_mat_2367_2_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_2, 221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) + + const __m256i rhs_mat_0145_3_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_3, 221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) + const __m256i rhs_mat_2367_3_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_3, 221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) + + // Scale values - Load the wight scale values of block_q4_0x8 + const __m256 col_scale_f32 = GGML_F32Cx8_LOAD(b_ptr[b].d); + + // Load the four block_q4_0 quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3 + // Loaded as set of 128 bit vectors and repeated into a 256 bit vector + __m256i lhs_mat_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs))); + __m256i lhs_mat_01_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 0); + __m256i lhs_mat_23_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 17); + __m256i lhs_mat_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 32))); + __m256i lhs_mat_01_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 0); + __m256i lhs_mat_23_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 17); + __m256i lhs_mat_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 64))); + __m256i lhs_mat_01_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 0); + __m256i lhs_mat_23_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 17); + __m256i lhs_mat_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 96))); + __m256i lhs_mat_01_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 0); + __m256i lhs_mat_23_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 17); + + // Shuffle pattern one - left side input + + const __m256i lhs_mat_01_0_sp1 = _mm256_shuffle_epi32(lhs_mat_01_0, 160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) + const __m256i lhs_mat_23_0_sp1 = _mm256_shuffle_epi32(lhs_mat_23_0, 160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) + + const __m256i lhs_mat_01_1_sp1 = _mm256_shuffle_epi32(lhs_mat_01_1, 160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) + const __m256i lhs_mat_23_1_sp1 = _mm256_shuffle_epi32(lhs_mat_23_1, 160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) + + const __m256i lhs_mat_01_2_sp1 = _mm256_shuffle_epi32(lhs_mat_01_2, 160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) + const __m256i lhs_mat_23_2_sp1 = _mm256_shuffle_epi32(lhs_mat_23_2, 160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) + + const __m256i lhs_mat_01_3_sp1 = _mm256_shuffle_epi32(lhs_mat_01_3, 160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) + const __m256i lhs_mat_23_3_sp1 = _mm256_shuffle_epi32(lhs_mat_23_3, 160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) + + // Shuffle pattern two - left side input + + const __m256i lhs_mat_01_0_sp2 = _mm256_shuffle_epi32(lhs_mat_01_0, 245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) + const __m256i lhs_mat_23_0_sp2 = _mm256_shuffle_epi32(lhs_mat_23_0, 245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) + + const __m256i lhs_mat_01_1_sp2 = _mm256_shuffle_epi32(lhs_mat_01_1, 245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) + const __m256i lhs_mat_23_1_sp2 = _mm256_shuffle_epi32(lhs_mat_23_1, 245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) + + const __m256i lhs_mat_01_2_sp2 = _mm256_shuffle_epi32(lhs_mat_01_2, 245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) + const __m256i lhs_mat_23_2_sp2 = _mm256_shuffle_epi32(lhs_mat_23_2, 245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) + + const __m256i lhs_mat_01_3_sp2 = _mm256_shuffle_epi32(lhs_mat_01_3, 245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) + const __m256i lhs_mat_23_3_sp2 = _mm256_shuffle_epi32(lhs_mat_23_3, 245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) + + // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane + // Resembles MMLAs into 2x2 matrices in ARM Version + __m256i iacc_mat_00_sp1 = + _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_01_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int(lhs_mat_01_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int(lhs_mat_01_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int(lhs_mat_01_0_sp1, rhs_mat_0145_0_sp1)); + __m256i iacc_mat_01_sp1 = + _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_01_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int(lhs_mat_01_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int(lhs_mat_01_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int(lhs_mat_01_0_sp1, rhs_mat_2367_0_sp1)); + __m256i iacc_mat_10_sp1 = + _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_23_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int(lhs_mat_23_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int(lhs_mat_23_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int(lhs_mat_23_0_sp1, rhs_mat_0145_0_sp1)); + __m256i iacc_mat_11_sp1 = + _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_23_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int(lhs_mat_23_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int(lhs_mat_23_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int(lhs_mat_23_0_sp1, rhs_mat_2367_0_sp1)); + __m256i iacc_mat_00_sp2 = + _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_01_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int(lhs_mat_01_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int(lhs_mat_01_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int(lhs_mat_01_0_sp2, rhs_mat_0145_0_sp2)); + __m256i iacc_mat_01_sp2 = + _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_01_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int(lhs_mat_01_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int(lhs_mat_01_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int(lhs_mat_01_0_sp2, rhs_mat_2367_0_sp2)); + __m256i iacc_mat_10_sp2 = + _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_23_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int(lhs_mat_23_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int(lhs_mat_23_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int(lhs_mat_23_0_sp2, rhs_mat_0145_0_sp2)); + __m256i iacc_mat_11_sp2 = + _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_23_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int(lhs_mat_23_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int(lhs_mat_23_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int(lhs_mat_23_0_sp2, rhs_mat_2367_0_sp2)); + + // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block + __m256i iacc_mat_00 = _mm256_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2); + __m256i iacc_mat_01 = _mm256_add_epi32(iacc_mat_01_sp1, iacc_mat_01_sp2); + __m256i iacc_mat_10 = _mm256_add_epi32(iacc_mat_10_sp1, iacc_mat_10_sp2); + __m256i iacc_mat_11 = _mm256_add_epi32(iacc_mat_11_sp1, iacc_mat_11_sp2); + + + // Straighten out to make 4 row vectors + __m256i iacc_row_0 = _mm256_blend_epi32(iacc_mat_00, _mm256_shuffle_epi32(iacc_mat_01, 78), 204); + __m256i iacc_row_1 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_00, 78), iacc_mat_01, 204); + __m256i iacc_row_2 = _mm256_blend_epi32(iacc_mat_10, _mm256_shuffle_epi32(iacc_mat_11, 78), 204); + __m256i iacc_row_3 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_10, 78), iacc_mat_11, 204); + + // Load the scale(d) values for all the 4 Q8_0 blocks and repeat it across lanes + const __m256 row_scale_f32 = GGML_F32Cx8_REPEAT_LOAD(a_ptr[b].d, loadMask); + + // Multiply with appropiate scales and accumulate + acc_rows[0] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_0), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[0]); + acc_rows[1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_1), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[1]); + acc_rows[2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_2), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[2]); + acc_rows[3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_3), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[3]); + } + + // Store the accumulated values + for (int i = 0; i < 4; i++) { + _mm256_storeu_ps((float *)(s + ((y * 4 + i) * bs + x * 8)), acc_rows[i]); + } + } + } #else float sumf[4][8]; int sumi; From bdf314f38a2c90e18285f7d7067e8d736a14000a Mon Sep 17 00:00:00 2001 From: slaren Date: Thu, 5 Sep 2024 02:19:39 +0200 Subject: [PATCH 22/22] llama-bench : fix NUL terminators in CPU name (#9313) --- examples/llama-bench/llama-bench.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index bf97cc684..0c9bcb777 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -124,6 +124,9 @@ static std::string get_cpu_info() { (LPBYTE)cpu_brand, &cpu_brand_size) == ERROR_SUCCESS) { id.assign(cpu_brand, cpu_brand_size); + if (id.find('\0') != std::string::npos) { + id.resize(id.find('\0')); + } } RegCloseKey(hKey); #endif