From d67caea0d6e6c303d31b01d0a010973e6c908dff Mon Sep 17 00:00:00 2001 From: slaren Date: Thu, 6 Jun 2024 07:17:21 +0200 Subject: [PATCH 01/21] docker : add openmp lib (#7780) --- .devops/full-cuda.Dockerfile | 2 +- .devops/full.Dockerfile | 2 +- .devops/main-cuda.Dockerfile | 3 +++ .devops/main-vulkan.Dockerfile | 2 +- .devops/main.Dockerfile | 3 +++ .devops/server-cuda.Dockerfile | 2 +- .devops/server.Dockerfile | 2 +- 7 files changed, 11 insertions(+), 5 deletions(-) diff --git a/.devops/full-cuda.Dockerfile b/.devops/full-cuda.Dockerfile index c01006efe..f6073f662 100644 --- a/.devops/full-cuda.Dockerfile +++ b/.devops/full-cuda.Dockerfile @@ -12,7 +12,7 @@ FROM ${BASE_CUDA_DEV_CONTAINER} as build ARG CUDA_DOCKER_ARCH=all RUN apt-get update && \ - apt-get install -y build-essential python3 python3-pip git libcurl4-openssl-dev + apt-get install -y build-essential python3 python3-pip git libcurl4-openssl-dev libgomp1 COPY requirements.txt requirements.txt COPY requirements requirements diff --git a/.devops/full.Dockerfile b/.devops/full.Dockerfile index 6d5943a2f..6f19afa9c 100644 --- a/.devops/full.Dockerfile +++ b/.devops/full.Dockerfile @@ -3,7 +3,7 @@ ARG UBUNTU_VERSION=22.04 FROM ubuntu:$UBUNTU_VERSION as build RUN apt-get update && \ - apt-get install -y build-essential python3 python3-pip git libcurl4-openssl-dev + apt-get install -y build-essential python3 python3-pip git libcurl4-openssl-dev libgomp1 COPY requirements.txt requirements.txt COPY requirements requirements diff --git a/.devops/main-cuda.Dockerfile b/.devops/main-cuda.Dockerfile index 23f428944..5bcd45fe8 100644 --- a/.devops/main-cuda.Dockerfile +++ b/.devops/main-cuda.Dockerfile @@ -27,6 +27,9 @@ RUN make -j$(nproc) FROM ${BASE_CUDA_RUN_CONTAINER} as runtime +RUN apt-get update && \ + apt-get install -y libgomp1 + COPY --from=build /app/main /main ENTRYPOINT [ "/main" ] diff --git a/.devops/main-vulkan.Dockerfile b/.devops/main-vulkan.Dockerfile index 6c2b2ed5b..1bdb52803 100644 --- a/.devops/main-vulkan.Dockerfile +++ b/.devops/main-vulkan.Dockerfile @@ -3,7 +3,7 @@ ARG UBUNTU_VERSION=jammy FROM ubuntu:$UBUNTU_VERSION as build # Install build tools -RUN apt update && apt install -y git build-essential cmake wget +RUN apt update && apt install -y git build-essential cmake wget libgomp1 # Install Vulkan SDK RUN wget -qO - https://packages.lunarg.com/lunarg-signing-key-pub.asc | apt-key add - && \ diff --git a/.devops/main.Dockerfile b/.devops/main.Dockerfile index 763d75fce..98a58a4b9 100644 --- a/.devops/main.Dockerfile +++ b/.devops/main.Dockerfile @@ -13,6 +13,9 @@ RUN make -j$(nproc) FROM ubuntu:$UBUNTU_VERSION as runtime +RUN apt-get update && \ + apt-get install -y libgomp1 + COPY --from=build /app/main /main ENV LC_ALL=C.utf8 diff --git a/.devops/server-cuda.Dockerfile b/.devops/server-cuda.Dockerfile index 7f5228185..2532e69e8 100644 --- a/.devops/server-cuda.Dockerfile +++ b/.devops/server-cuda.Dockerfile @@ -30,7 +30,7 @@ RUN make -j$(nproc) FROM ${BASE_CUDA_RUN_CONTAINER} as runtime RUN apt-get update && \ - apt-get install -y libcurl4-openssl-dev + apt-get install -y libcurl4-openssl-dev libgomp1 COPY --from=build /app/server /server diff --git a/.devops/server.Dockerfile b/.devops/server.Dockerfile index 0d09d3627..a41c16b65 100644 --- a/.devops/server.Dockerfile +++ b/.devops/server.Dockerfile @@ -16,7 +16,7 @@ RUN make -j$(nproc) FROM ubuntu:$UBUNTU_VERSION as runtime RUN apt-get update && \ - apt-get install -y libcurl4-openssl-dev + apt-get install -y libcurl4-openssl-dev libgomp1 COPY --from=build /app/server /server From 2d08b7fbb483c14bd2b173d4cd51ea3a4f862e8f Mon Sep 17 00:00:00 2001 From: slaren Date: Thu, 6 Jun 2024 07:19:49 +0200 Subject: [PATCH 02/21] docker : build only main and server in their images (#7782) * add openmp lib to dockerfiles * build only main and server in their docker images --- .devops/main-cuda.Dockerfile | 2 +- .devops/main-rocm.Dockerfile | 2 +- .devops/main.Dockerfile | 2 +- .devops/server-cuda.Dockerfile | 2 +- .devops/server.Dockerfile | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.devops/main-cuda.Dockerfile b/.devops/main-cuda.Dockerfile index 5bcd45fe8..2aec4a85d 100644 --- a/.devops/main-cuda.Dockerfile +++ b/.devops/main-cuda.Dockerfile @@ -23,7 +23,7 @@ ENV CUDA_DOCKER_ARCH=${CUDA_DOCKER_ARCH} # Enable CUDA ENV LLAMA_CUDA=1 -RUN make -j$(nproc) +RUN make -j$(nproc) main FROM ${BASE_CUDA_RUN_CONTAINER} as runtime diff --git a/.devops/main-rocm.Dockerfile b/.devops/main-rocm.Dockerfile index 37576d68e..dcaeb3e72 100644 --- a/.devops/main-rocm.Dockerfile +++ b/.devops/main-rocm.Dockerfile @@ -40,6 +40,6 @@ ENV LLAMA_HIPBLAS=1 ENV CC=/opt/rocm/llvm/bin/clang ENV CXX=/opt/rocm/llvm/bin/clang++ -RUN make -j$(nproc) +RUN make -j$(nproc) main ENTRYPOINT [ "/app/main" ] diff --git a/.devops/main.Dockerfile b/.devops/main.Dockerfile index 98a58a4b9..d2514c4ba 100644 --- a/.devops/main.Dockerfile +++ b/.devops/main.Dockerfile @@ -9,7 +9,7 @@ WORKDIR /app COPY . . -RUN make -j$(nproc) +RUN make -j$(nproc) main FROM ubuntu:$UBUNTU_VERSION as runtime diff --git a/.devops/server-cuda.Dockerfile b/.devops/server-cuda.Dockerfile index 2532e69e8..4e9747b82 100644 --- a/.devops/server-cuda.Dockerfile +++ b/.devops/server-cuda.Dockerfile @@ -25,7 +25,7 @@ ENV LLAMA_CUDA=1 # Enable cURL ENV LLAMA_CURL=1 -RUN make -j$(nproc) +RUN make -j$(nproc) server FROM ${BASE_CUDA_RUN_CONTAINER} as runtime diff --git a/.devops/server.Dockerfile b/.devops/server.Dockerfile index a41c16b65..bee63b966 100644 --- a/.devops/server.Dockerfile +++ b/.devops/server.Dockerfile @@ -11,7 +11,7 @@ COPY . . ENV LLAMA_CURL=1 -RUN make -j$(nproc) +RUN make -j$(nproc) server FROM ubuntu:$UBUNTU_VERSION as runtime From f5d7b268ec4bf8628aa6ccc9f6631d0230dde76f Mon Sep 17 00:00:00 2001 From: Joan Fontanals Date: Thu, 6 Jun 2024 09:22:41 +0200 Subject: [PATCH 03/21] llama : add jina v2 base code (#7596) * feat: add changes to handle jina v2 base code * fix: do not complicate things * fix: fix the usage of the code model * fix: fix comments * fix: fix linting issues * fix: remove ollama patches * style : minor --------- Co-authored-by: Georgi Gerganov --- convert-hf-to-gguf-update.py | 1 + convert-hf-to-gguf.py | 7 ++++++- gguf-py/gguf/constants.py | 1 + gguf-py/gguf/tensor_mapping.py | 3 +++ llama.cpp | 17 +++++++++++++---- 5 files changed, 24 insertions(+), 5 deletions(-) diff --git a/convert-hf-to-gguf-update.py b/convert-hf-to-gguf-update.py index 6dae1a594..f43b15760 100755 --- a/convert-hf-to-gguf-update.py +++ b/convert-hf-to-gguf-update.py @@ -83,6 +83,7 @@ models = [ {"name": "jina-v2-es", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-es", }, {"name": "jina-v2-de", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-de", }, {"name": "smaug-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/abacusai/Smaug-Llama-3-70B-Instruct", }, + {"name": "jina-v2-code", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-code", }, ] diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index afb9704c8..a86864f04 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -475,6 +475,9 @@ class Model: if chkhsh == "c136ed14d01c2745d4f60a9596ae66800e2b61fa45643e72436041855ad4089d": # ref: https://huggingface.co/abacusai/Smaug-Llama-3-70B-Instruct res = "smaug-bpe" + if chkhsh == "7967bfa498ade6b757b064f31e964dddbb80f8f9a4d68d4ba7998fcf281c531a": + # ref: https://huggingface.co/jinaai/jina-embeddings-v2-base-code + res = "jina-v2-code" if res is None: logger.warning("\n") @@ -2452,11 +2455,13 @@ class JinaBertV2Model(BertModel): def get_tensors(self): for name, data in super().get_tensors(): - if 'gated_layers' in name: + if 'gated_layer' in name: d1 = data[:self.intermediate_size, :] name1 = name.replace('gated_layers', 'gated_layers_w') + name1 = name1.replace('up_gated_layer', 'gated_layers_v') d2 = data[self.intermediate_size:, :] name2 = name.replace('gated_layers', 'gated_layers_v') + name2 = name2.replace('up_gated_layer', 'gated_layers_w') yield name1, d1 yield name2, d2 continue diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index a3c024c89..8908585cc 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -415,6 +415,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.TOKEN_EMBD_NORM, MODEL_TENSOR.TOKEN_TYPES, + MODEL_TENSOR.ATTN_NORM_2, MODEL_TENSOR.ATTN_OUT_NORM, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_Q_NORM, diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 83e3c4c33..81b4992a5 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -102,6 +102,7 @@ class TensorNameMap: # Attention norm 2 MODEL_TENSOR.ATTN_NORM_2: ( "transformer.h.{bid}.ln_attn", # falcon40b + "encoder.layer.{bid}.layer_norm_1", # jina-v2-code ), # Attention query-key-value @@ -311,6 +312,7 @@ class TensorNameMap: "model.layers.{bid}.mlp.c_proj", # starcoder2 "encoder.layer.{bid}.mlp.wo", # jina-bert-v2 "model.layers.{bid}.residual_mlp.w2", # arctic + "encoder.layer.{bid}.mlp.down_layer", # jina-bert-v2 ), MODEL_TENSOR.FFN_DOWN_EXP: ( @@ -350,6 +352,7 @@ class TensorNameMap: "encoder.layers.{bid}.norm2", # nomic-bert "transformer.decoder_layer.{bid}.rms_norm_3", # Grok "encoder.layer.{bid}.mlp.layernorm", # jina-bert-v2 + "encoder.layer.{bid}.layer_norm_2" # jina-v2-code ), MODEL_TENSOR.SSM_IN: ( diff --git a/llama.cpp b/llama.cpp index 414d390e8..cefb4d1d5 100644 --- a/llama.cpp +++ b/llama.cpp @@ -704,6 +704,7 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, { LLM_TENSOR_TOKEN_TYPES, "token_types" }, + { LLM_TENSOR_ATTN_NORM_2, "blk.%d.attn_norm_2" }, { LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" }, { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, @@ -4653,8 +4654,7 @@ static void llm_load_vocab( LLAMA_LOG_WARN("%s: ************************************ \n", __func__); LLAMA_LOG_WARN("%s: \n", __func__); vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; - } else if ( - tokenizer_pre == "default") { + } else if (tokenizer_pre == "default") { vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; } else if ( tokenizer_pre == "llama3" || @@ -4681,7 +4681,8 @@ static void llm_load_vocab( tokenizer_pre == "jina-es" || tokenizer_pre == "jina-de" || tokenizer_pre == "jina-v2-es" || - tokenizer_pre == "jina-v2-de") { + tokenizer_pre == "jina-v2-de" || + tokenizer_pre == "jina-v2-code") { vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_GPT2; } else if ( tokenizer_pre == "refact") { @@ -5515,7 +5516,7 @@ static bool llm_load_tensors( layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}); } else { - layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); + layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); } layer.layer_out_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}); @@ -5556,6 +5557,9 @@ static bool llm_load_tensors( layer.attn_out_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}); //output_norm layer.attn_out_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i), {n_embd}); + layer.attn_norm_2 = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.attn_norm_2_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); @@ -8519,6 +8523,11 @@ struct llm_build_context { // attention layer norm cur = llm_build_norm(ctx0, cur, hparams, model.layers[il].attn_out_norm, model.layers[il].attn_out_norm_b, LLM_NORM, cb, il); + if (model.layers[il].attn_norm_2 != nullptr) { + cur = ggml_add(ctx0, cur, inpL); // re-add the layer input + cur = llm_build_norm(ctx0, cur, hparams, model.layers[il].attn_norm_2, model.layers[il].attn_norm_2_b, LLM_NORM, cb, il); + } + struct ggml_tensor * ffn_inp = cur; cb(ffn_inp, "ffn_inp", il); From 55b2d0849d3ec9e45e4a4d9e480f5fa7977872a6 Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Thu, 6 Jun 2024 10:07:06 +0100 Subject: [PATCH 04/21] grammars: x{min,max} repetition operator (#6640) * grammars: x{min,max} repetition operator + tweak +/*/? to avoid duplication of original over alternates * grammars: handle `x{n}` and fix `x{n,n}` * grammars: document new repetition operators * grammars: uniform use of int for min & max * grammars: refactor parser test * grammar: parsing tests w/ natural pretty print of updated expectations * grammars: much prettier print of expectations (+ TEST_GRAMMAR_PARSER_PRINT_ALL=1 to force all) * grammars: improve test pretty print again * grammars: pretty print rules and chars * grammars: fix copy rule skipping * grammars: disallow `a{,}` (not allowed in regexps) * Update common/grammar-parser.cpp Co-authored-by: Clint Herron * grammars: fix copy rule skipping (again) & display of expectations * grammars: more test cases * grammars: update reps parsing to bring ? / * / + closer to before * json: use new GBNF repetitions{m,n} syntax * grammars: update performance gotchas w/ repetition advice * Update examples/json_schema_to_grammar.py Co-authored-by: Clint Herron * Update examples/server/public/json-schema-to-grammar.mjs Co-authored-by: Clint Herron * grammars: comment on rule repetitions * grammars: ensure unambiguous number alternatives * grammar: nit typo switched error msgs * grammar: nit numbering in comment * json: update numeric rule to be unambiguous * Apply suggestions from code review Co-authored-by: Clint Herron * Update examples/server/public/json-schema-to-grammar.mjs Co-authored-by: Clint Herron * json: fix integral-part * grammar: add repetition tests --------- Co-authored-by: Clint Herron --- common/grammar-parser.cpp | 144 ++++- common/json-schema-to-grammar.cpp | 80 +-- examples/json_schema_to_grammar.py | 70 +-- examples/pydantic_models_to_grammar.py | 2 +- .../server/public/json-schema-to-grammar.mjs | 67 +- grammars/README.md | 12 +- tests/test-grammar-integration.cpp | 76 +++ tests/test-grammar-parser.cpp | 591 +++++++++++++----- tests/test-json-schema-to-grammar.cpp | 112 ++-- 9 files changed, 736 insertions(+), 418 deletions(-) diff --git a/common/grammar-parser.cpp b/common/grammar-parser.cpp index b5bc7d49b..79d2b0354 100644 --- a/common/grammar-parser.cpp +++ b/common/grammar-parser.cpp @@ -46,8 +46,12 @@ namespace grammar_parser { state.rules[rule_id] = rule; } + static bool is_digit_char(char c) { + return '0' <= c && c <= '9'; + } + static bool is_word_char(char c) { - return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || ('0' <= c && c <= '9'); + return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || is_digit_char(c); } static std::pair parse_hex(const char * src, int size) { @@ -99,6 +103,17 @@ namespace grammar_parser { return pos; } + static const char * parse_int(const char * src) { + const char * pos = src; + while (is_digit_char(*pos)) { + pos++; + } + if (pos == src) { + throw std::runtime_error(std::string("expecting integer at ") + src); + } + return pos; + } + static std::pair parse_char(const char * src) { if (*src == '\\') { switch (src[1]) { @@ -137,6 +152,60 @@ namespace grammar_parser { bool is_nested) { size_t last_sym_start = out_elements.size(); const char * pos = src; + + auto handle_repetitions = [&](int min_times, int max_times) { + + if (last_sym_start == out_elements.size()) { + throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos); + } + + // apply transformation to previous symbol (last_sym_start to end) according to + // the following rewrite rules: + // S{m,n} --> S S S (m times) S'(n-m) + // S'(x) ::= S S'(x-1) | + // (... n-m definitions of these S' rules ...) + // S'(1) ::= S | + // S{m,} --> S S S (m times) S' + // S' ::= S S' | + // S* --> S{0,} + // --> S' ::= S S' | + // S+ --> S{1,} + // --> S S' + // S' ::= S S' | + // S? --> S{0,1} + // --> S' + // S' ::= S | + + std::vector previous_elements(out_elements.begin() + last_sym_start, out_elements.end()); + if (min_times == 0) { + out_elements.resize(last_sym_start); + } else { + // Repeat the previous elements (min_times - 1) times + for (int i = 1; i < min_times; i++) { + out_elements.insert(out_elements.end(), previous_elements.begin(), previous_elements.end()); + } + } + + uint32_t last_rec_rule_id = 0; + auto n_opt = max_times < 0 ? 1 : max_times - min_times; + + std::vector rec_rule(previous_elements); + for (int i = 0; i < n_opt; i++) { + rec_rule.resize(previous_elements.size()); + uint32_t rec_rule_id = generate_symbol_id(state, rule_name); + if (i > 0 || max_times < 0) { + rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times < 0 ? rec_rule_id : last_rec_rule_id}); + } + rec_rule.push_back({LLAMA_GRETYPE_ALT, 0}); + rec_rule.push_back({LLAMA_GRETYPE_END, 0}); + add_rule(state, rec_rule_id, rec_rule); + last_rec_rule_id = rec_rule_id; + } + if (n_opt > 0) { + out_elements.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id}); + } + }; + while (*pos) { if (*pos == '"') { // literal string pos++; @@ -197,40 +266,47 @@ namespace grammar_parser { throw std::runtime_error(std::string("expecting ')' at ") + pos); } pos = parse_space(pos + 1, is_nested); - } else if (*pos == '*' || *pos == '+' || *pos == '?') { // repetition operator - if (last_sym_start == out_elements.size()) { - throw std::runtime_error(std::string("expecting preceding item to */+/? at ") + pos); - } - - // apply transformation to previous symbol (last_sym_start to end) according to - // rewrite rules: - // S* --> S' ::= S S' | - // S+ --> S' ::= S S' | S - // S? --> S' ::= S | - uint32_t sub_rule_id = generate_symbol_id(state, rule_name); - std::vector sub_rule; - // add preceding symbol to generated rule - sub_rule.insert( - sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end()); - if (*pos == '*' || *pos == '+') { - // cause generated rule to recurse - sub_rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id}); - } - // mark start of alternate def - sub_rule.push_back({LLAMA_GRETYPE_ALT, 0}); - if (*pos == '+') { - // add preceding symbol as alternate only for '+' (otherwise empty) - sub_rule.insert( - sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end()); - } - sub_rule.push_back({LLAMA_GRETYPE_END, 0}); - add_rule(state, sub_rule_id, sub_rule); - - // in original rule, replace previous symbol with reference to generated rule - out_elements.resize(last_sym_start); - out_elements.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id}); - + } else if (*pos == '*') { pos = parse_space(pos + 1, is_nested); + handle_repetitions(0, -1); + } else if (*pos == '+') { + pos = parse_space(pos + 1, is_nested); + handle_repetitions(1, -1); + } else if (*pos == '?') { + pos = parse_space(pos + 1, is_nested); + handle_repetitions(0, 1); + } else if (*pos == '{') { + pos = parse_space(pos + 1, is_nested); + + if (!is_digit_char(*pos)) { + throw std::runtime_error(std::string("expecting an int at ") + pos); + } + const char * int_end = parse_int(pos); + int min_times = std::stoul(std::string(pos, int_end - pos)); + pos = parse_space(int_end, is_nested); + + int max_times = -1; + + if (*pos == '}') { + max_times = min_times; + pos = parse_space(pos + 1, is_nested); + } else if (*pos == ',') { + pos = parse_space(pos + 1, is_nested); + + if (is_digit_char(*pos)) { + const char * int_end = parse_int(pos); + max_times = std::stoul(std::string(pos, int_end - pos)); + pos = parse_space(int_end, is_nested); + } + + if (*pos != '}') { + throw std::runtime_error(std::string("expecting '}' at ") + pos); + } + pos = parse_space(pos + 1, is_nested); + } else { + throw std::runtime_error(std::string("expecting ',' at ") + pos); + } + handle_repetitions(min_times, max_times); } else { break; } diff --git a/common/json-schema-to-grammar.cpp b/common/json-schema-to-grammar.cpp index 9a71f5d8d..737bae27c 100644 --- a/common/json-schema-to-grammar.cpp +++ b/common/json-schema-to-grammar.cpp @@ -16,58 +16,27 @@ static std::string join(Iterator begin, Iterator end, const std::string & separa static std::string repeat(const std::string & str, size_t n); -static std::string build_repetition(const std::string & item_rule, int min_items, int max_items, const std::string & separator_rule = "", bool item_rule_is_literal = false) { +static std::string build_repetition(const std::string & item_rule, int min_items, int max_items, const std::string & separator_rule = "") { + auto has_max = max_items != std::numeric_limits::max(); + + if (min_items == 0 && max_items == 1) { + return item_rule + "?"; + } + if (separator_rule.empty()) { - if (min_items == 0 && max_items == 1) { - return item_rule + "?"; - } else if (min_items == 1 && max_items == std::numeric_limits::max()) { + if (min_items == 1 && !has_max) { return item_rule + "+"; - } - } - - std::string result; - if (min_items > 0) { - if (item_rule_is_literal && separator_rule.empty()) { - result = "\"" + repeat(std::string(item_rule.begin() + 1, item_rule.end() - 1), min_items) + "\""; + } else if (min_items == 0 && !has_max) { + return item_rule + "*"; } else { - std::vector items(min_items, item_rule); - result = join(items.begin(), items.end(), separator_rule.empty() ? " " : " " + separator_rule + " "); + return item_rule + "{" + std::to_string(min_items) + "," + (has_max ? std::to_string(max_items) : "") + "}"; } } - std::function opt_repetitions = [&](int up_to_n, bool prefix_with_sep) -> std::string { - auto content = prefix_with_sep && !separator_rule.empty() ? separator_rule + " " + item_rule : item_rule; - - if (up_to_n == 0) { - return ""; - } else if (up_to_n == 1) { - return "(" + content + ")?"; - } else if (!separator_rule.empty() && !prefix_with_sep) { - return "(" + content + " " + opt_repetitions(up_to_n - 1, true) + ")?"; - } else { - std::string res = repeat("(" + content + " ", up_to_n); - // strip trailing space - res = res.substr(0, res.length() - 1); - res += repeat(")?", up_to_n); - return res; - } - }; - - if (min_items > 0 && max_items != min_items) { - result += " "; + auto result = item_rule + " " + build_repetition("(" + separator_rule + " " + item_rule + ")", min_items == 0 ? 0 : min_items - 1, has_max ? max_items - 1 : max_items); + if (min_items == 0) { + result = "(" + result + ")?"; } - - if (max_items != std::numeric_limits::max()) { - result += opt_repetitions(max_items - min_items, min_items > 0); - } else { - std::string item_operator = "(" + (separator_rule.empty() ? "" : separator_rule + " ") + item_rule + ")"; - if (min_items == 0 && !separator_rule.empty()) { - result = "(" + item_rule + " " + item_operator + "*)?"; - } else { - result += item_operator + "*"; - } - } - return result; } @@ -78,30 +47,24 @@ struct BuiltinRule { std::vector deps; }; -const std::string _up_to_15_digits = build_repetition("[0-9]", 0, 15); - std::unordered_map PRIMITIVE_RULES = { {"boolean", {"(\"true\" | \"false\") space", {}}}, - {"decimal-part", {"[0-9] " + _up_to_15_digits, {}}}, - {"integral-part", {"[0-9] | [1-9] " + _up_to_15_digits, {}}}, + {"decimal-part", {"[0-9]{1,16}", {}}}, + {"integral-part", {"[0] | [1-9] [0-9]{0,15}", {}}}, {"number", {"(\"-\"? integral-part) (\".\" decimal-part)? ([eE] [-+]? integral-part)? space", {"integral-part", "decimal-part"}}}, {"integer", {"(\"-\"? integral-part) space", {"integral-part"}}}, {"value", {"object | array | string | number | boolean | null", {"object", "array", "string", "number", "boolean", "null"}}}, {"object", {"\"{\" space ( string \":\" space value (\",\" space string \":\" space value)* )? \"}\" space", {"string", "value"}}}, {"array", {"\"[\" space ( value (\",\" space value)* )? \"]\" space", {"value"}}}, - {"uuid", {"\"\\\"\" [0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F] " - "\"-\" [0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F] " - "\"-\" [0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F] " - "\"-\" [0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F] " - "\"-\" [0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F] \"\\\"\" space", {}}}, - {"char", {"[^\"\\\\] | \"\\\\\" ([\"\\\\/bfnrt] | \"u\" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])", {}}}, + {"uuid", {"\"\\\"\" [0-9a-fA-F]{8} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{12} \"\\\"\" space", {}}}, + {"char", {"[^\"\\\\] | \"\\\\\" ([\"\\\\/bfnrt] | \"u\" [0-9a-fA-F]{4})", {}}}, {"string", {"\"\\\"\" char* \"\\\"\" space", {"char"}}}, {"null", {"\"null\" space", {}}}, }; std::unordered_map STRING_FORMAT_RULES = { - {"date", {"[0-9] [0-9] [0-9] [0-9] \"-\" ( \"0\" [1-9] | \"1\" [0-2] ) \"-\" ( \"0\" [1-9] | [1-2] [0-9] | \"3\" [0-1] )", {}}}, - {"time", {"([01] [0-9] | \"2\" [0-3]) \":\" [0-5] [0-9] \":\" [0-5] [0-9] ( \".\" [0-9] [0-9] [0-9] )? ( \"Z\" | ( \"+\" | \"-\" ) ( [01] [0-9] | \"2\" [0-3] ) \":\" [0-5] [0-9] )", {}}}, + {"date", {"[0-9]{4} \"-\" ( \"0\" [1-9] | \"1\" [0-2] ) \"-\" ( \"0\" [1-9] | [1-2] [0-9] | \"3\" [0-1] )", {}}}, + {"time", {"([01] [0-9] | \"2\" [0-3]) \":\" [0-5] [0-9] \":\" [0-5] [0-9] ( \".\" [0-9]{3} )? ( \"Z\" | ( \"+\" | \"-\" ) ( [01] [0-9] | \"2\" [0-3] ) \":\" [0-5] [0-9] )", {}}}, {"date-time", {"date \"T\" time", {"date", "time"}}}, {"date-string", {"\"\\\"\" date \"\\\"\" space", {"date"}}}, {"time-string", {"\"\\\"\" time \"\\\"\" space", {"time"}}}, @@ -385,8 +348,7 @@ private: sub_is_literal ? "\"" + sub + "\"" : sub, min_times, max_times, - "", - sub_is_literal + "" ); seq.back().second = false; } else { diff --git a/examples/json_schema_to_grammar.py b/examples/json_schema_to_grammar.py index 826cd3f72..7d889c3fe 100755 --- a/examples/json_schema_to_grammar.py +++ b/examples/json_schema_to_grammar.py @@ -6,52 +6,22 @@ import re import sys from typing import Any, Dict, List, Set, Tuple, Union -def _build_repetition(item_rule, min_items, max_items, separator_rule=None, item_rule_is_literal=False): + +def _build_repetition(item_rule, min_items, max_items, separator_rule=None): + + if min_items == 0 and max_items == 1: + return f'{item_rule}?' + if not separator_rule: - if min_items == 0 and max_items == 1: - return f'{item_rule}?' - elif min_items == 1 and max_items is None: + if min_items == 1 and max_items is None: return f'{item_rule}+' - - result = '' - - if min_items > 0: - if item_rule_is_literal and separator_rule is None: - result = '"' + (item_rule[1:-1] * min_items) + '"' + elif min_items == 0 and max_items is None: + return f'{item_rule}*' else: - result = (f' {separator_rule} ' if separator_rule else ' ').join([item_rule] * min_items) + return f'{item_rule}{{{min_items},{max_items if max_items is not None else ""}}}' - def opt_repetitions(up_to_n, prefix_with_sep=False): - ''' - - n=4, no sep: '(a (a (a (a)?)?)?)?' - - n=4, sep=',', prefix: '("," a ("," a ("," a ("," a)?)?)?)?' - - n=4, sep=',', no prefix: '(a ("," a ("," a ("," a)?)?)?)?' - ''' - - content = f'{separator_rule} {item_rule}' if prefix_with_sep and separator_rule else item_rule - if up_to_n == 0: - return '' - elif up_to_n == 1: - return f'({content})?' - elif separator_rule and not prefix_with_sep: - return f'({content} {opt_repetitions(up_to_n - 1, prefix_with_sep=True)})?' - else: - return (f'({content} ' * up_to_n).rstrip() + (')?' * up_to_n) - - if min_items > 0 and max_items != min_items: - result += ' ' - - if max_items is not None: - result += opt_repetitions(max_items - min_items, prefix_with_sep=min_items > 0) - else: - item_operator = f'({separator_rule + " " if separator_rule else ""}{item_rule})' - - if min_items == 0 and separator_rule: - result = f'({item_rule} {item_operator}*)?' - else: - result += f'{item_operator}*' - - return result + result = item_rule + ' ' + _build_repetition(f'({separator_rule} {item_rule})', min_items - 1 if min_items > 0 else 0, max_items - 1 if max_items is not None else None) + return f'({result})?' if min_items == 0 else result class BuiltinRule: @@ -59,31 +29,29 @@ class BuiltinRule: self.content = content self.deps = deps or [] -_up_to_15_digits = _build_repetition('[0-9]', 0, 15) - # whitespace is constrained to a single space char to prevent model "running away" in # whitespace. Also maybe improves generation quality? SPACE_RULE = '" "?' PRIMITIVE_RULES = { 'boolean' : BuiltinRule('("true" | "false") space', []), - 'decimal-part' : BuiltinRule('[0-9] ' + _up_to_15_digits, []), - 'integral-part': BuiltinRule('[0-9] | [1-9] ' + _up_to_15_digits, []), + 'decimal-part' : BuiltinRule('[0-9]{1,16}', []), + 'integral-part': BuiltinRule('[0] | [1-9] [0-9]{0,15}', []), 'number' : BuiltinRule('("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space', ['integral-part', 'decimal-part']), 'integer' : BuiltinRule('("-"? integral-part) space', ['integral-part']), 'value' : BuiltinRule('object | array | string | number | boolean | null', ['object', 'array', 'string', 'number', 'boolean', 'null']), 'object' : BuiltinRule('"{" space ( string ":" space value ("," space string ":" space value)* )? "}" space', ['string', 'value']), 'array' : BuiltinRule('"[" space ( value ("," space value)* )? "]" space', ['value']), - 'uuid' : BuiltinRule(r'"\"" ' + ' "-" '.join('[0-9a-fA-F]' * n for n in [8, 4, 4, 4, 12]) + r' "\"" space', []), - 'char' : BuiltinRule(r'[^"\\] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])', []), + 'uuid' : BuiltinRule(r'"\"" [0-9a-fA-F]{8} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{12} "\"" space', []), + 'char' : BuiltinRule(r'[^"\\] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F]{4})', []), 'string' : BuiltinRule(r'"\"" char* "\"" space', ['char']), 'null' : BuiltinRule('"null" space', []), } # TODO: support "uri", "email" string formats STRING_FORMAT_RULES = { - 'date' : BuiltinRule('[0-9] [0-9] [0-9] [0-9] "-" ( "0" [1-9] | "1" [0-2] ) "-" ( \"0\" [1-9] | [1-2] [0-9] | "3" [0-1] )', []), - 'time' : BuiltinRule('([01] [0-9] | "2" [0-3]) ":" [0-5] [0-9] ":" [0-5] [0-9] ( "." [0-9] [0-9] [0-9] )? ( "Z" | ( "+" | "-" ) ( [01] [0-9] | "2" [0-3] ) ":" [0-5] [0-9] )', []), + 'date' : BuiltinRule('[0-9]{4} "-" ( "0" [1-9] | "1" [0-2] ) "-" ( \"0\" [1-9] | [1-2] [0-9] | "3" [0-1] )', []), + 'time' : BuiltinRule('([01] [0-9] | "2" [0-3]) ":" [0-5] [0-9] ":" [0-5] [0-9] ( "." [0-9]{3} )? ( "Z" | ( "+" | "-" ) ( [01] [0-9] | "2" [0-3] ) ":" [0-5] [0-9] )', []), 'date-time' : BuiltinRule('date "T" time', ['date', 'time']), 'date-string' : BuiltinRule('"\\"" date "\\"" space', ['date']), 'time-string' : BuiltinRule('"\\"" time "\\"" space', ['time']), @@ -333,7 +301,7 @@ class SchemaConverter: sub_rule_ids[sub] = id sub = id - seq[-1] = (_build_repetition(f'"{sub}"' if sub_is_literal else sub, min_times, max_times, item_rule_is_literal=sub_is_literal), False) + seq[-1] = (_build_repetition(f'"{sub}"' if sub_is_literal else sub, min_times, max_times), False) else: literal = '' while i < length: diff --git a/examples/pydantic_models_to_grammar.py b/examples/pydantic_models_to_grammar.py index 9acc7cc6d..f029c73a2 100644 --- a/examples/pydantic_models_to_grammar.py +++ b/examples/pydantic_models_to_grammar.py @@ -624,7 +624,7 @@ string ::= "\"" ( "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) )* "\"" ws ws ::= ([ \t\n] ws)? -float ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws +float ::= ("-"? ([0] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws integer ::= [0-9]+""" diff --git a/examples/server/public/json-schema-to-grammar.mjs b/examples/server/public/json-schema-to-grammar.mjs index 8e0be1b40..cef11eab8 100644 --- a/examples/server/public/json-schema-to-grammar.mjs +++ b/examples/server/public/json-schema-to-grammar.mjs @@ -2,57 +2,26 @@ const SPACE_RULE = '" "?'; function _buildRepetition(itemRule, minItems, maxItems, opts={}) { + if (minItems === 0 && maxItems === 1) { + return `${itemRule}?`; + } + + const separatorRule = opts.separatorRule ?? ''; const itemRuleIsLiteral = opts.itemRuleIsLiteral ?? false if (separatorRule === '') { - if (minItems === 0 && maxItems === 1) { - return `${itemRule}?`; - } else if (minItems === 1 && maxItems === undefined) { + if (minItems === 1 && maxItems === undefined) { return `${itemRule}+`; - } - } - - let result = ''; - if (minItems > 0) { - if (itemRuleIsLiteral && separatorRule === '') { - result = `"${itemRule.slice(1, -1).repeat(minItems)}"`; + } else if (minItems === 0 && maxItems === undefined) { + return `${itemRule}*`; } else { - result = Array.from({ length: minItems }, () => itemRule) - .join(separatorRule !== '' ? ` ${separatorRule} ` : ' '); + return `${itemRule}{${minItems},${maxItems !== undefined ? maxItems : ''}}`; } } - const optRepetitions = (upToN, prefixWithSep=false) => { - const content = separatorRule !== '' && prefixWithSep ? `${separatorRule} ${itemRule}` : itemRule; - if (upToN === 0) { - return ''; - } else if (upToN === 1) { - return `(${content})?`; - } else if (separatorRule !== '' && !prefixWithSep) { - return `(${content} ${optRepetitions(upToN - 1, true)})?`; - } else { - return Array.from({ length: upToN }, () => `(${content}`).join(' ').trim() + Array.from({ length: upToN }, () => ')?').join(''); - } - }; - - if (minItems > 0 && maxItems !== minItems) { - result += ' '; - } - - if (maxItems !== undefined) { - result += optRepetitions(maxItems - minItems, minItems > 0); - } else { - const itemOperator = `(${separatorRule !== '' ? separatorRule + ' ' : ''}${itemRule})`; - - if (minItems === 0 && separatorRule !== '') { - result = `(${itemRule} ${itemOperator}*)?`; - } else { - result += `${itemOperator}*`; - } - } - - return result; + const result = itemRule + ' ' + _buildRepetition(`(${separatorRule} ${itemRule})`, minItems > 0 ? minItems - 1 : 0, maxItems !== undefined ? maxItems - 1 : undefined); + return minItems === 0 ? `(${result})?` : result; } class BuiltinRule { @@ -62,27 +31,25 @@ class BuiltinRule { } } -const UP_TO_15_DIGITS = _buildRepetition('[0-9]', 0, 15); - const PRIMITIVE_RULES = { boolean : new BuiltinRule('("true" | "false") space', []), - 'decimal-part' : new BuiltinRule('[0-9] ' + UP_TO_15_DIGITS, []), - 'integral-part': new BuiltinRule('[0-9] | [1-9] ' + UP_TO_15_DIGITS, []), + 'decimal-part' : new BuiltinRule('[0-9]{1,16}', []), + 'integral-part': new BuiltinRule('[0] | [1-9] [0-9]{0,15}', []), number : new BuiltinRule('("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space', ['integral-part', 'decimal-part']), integer : new BuiltinRule('("-"? integral-part) space', ['integral-part']), value : new BuiltinRule('object | array | string | number | boolean | null', ['object', 'array', 'string', 'number', 'boolean', 'null']), object : new BuiltinRule('"{" space ( string ":" space value ("," space string ":" space value)* )? "}" space', ['string', 'value']), array : new BuiltinRule('"[" space ( value ("," space value)* )? "]" space', ['value']), - uuid : new BuiltinRule('"\\"" ' + [8, 4, 4, 4, 12].map(n => [...new Array(n)].map(_ => '[0-9a-fA-F]').join('')).join(' "-" ') + ' "\\"" space', []), - char : new BuiltinRule(`[^"\\\\] | "\\\\" (["\\\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])`, []), + uuid : new BuiltinRule('"\\"" [0-9a-fA-F]{8} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{12} "\\"" space', []), + char : new BuiltinRule(`[^"\\\\] | "\\\\" (["\\\\/bfnrt] | "u" [0-9a-fA-F]{4})`, []), string : new BuiltinRule(`"\\"" char* "\\"" space`, ['char']), null : new BuiltinRule('"null" space', []), }; // TODO: support "uri", "email" string formats const STRING_FORMAT_RULES = { - 'date' : new BuiltinRule('[0-9] [0-9] [0-9] [0-9] "-" ( "0" [1-9] | "1" [0-2] ) "-" ( \"0\" [1-9] | [1-2] [0-9] | "3" [0-1] )', []), - 'time' : new BuiltinRule('([01] [0-9] | "2" [0-3]) ":" [0-5] [0-9] ":" [0-5] [0-9] ( "." [0-9] [0-9] [0-9] )? ( "Z" | ( "+" | "-" ) ( [01] [0-9] | "2" [0-3] ) ":" [0-5] [0-9] )', []), + 'date' : new BuiltinRule('[0-9]{4} "-" ( "0" [1-9] | "1" [0-2] ) "-" ( \"0\" [1-9] | [1-2] [0-9] | "3" [0-1] )', []), + 'time' : new BuiltinRule('([01] [0-9] | "2" [0-3]) ":" [0-5] [0-9] ":" [0-5] [0-9] ( "." [0-9]{3} )? ( "Z" | ( "+" | "-" ) ( [01] [0-9] | "2" [0-3] ) ":" [0-5] [0-9] )', []), 'date-time' : new BuiltinRule('date "T" time', ['date', 'time']), 'date-string' : new BuiltinRule('"\\"" date "\\"" space', ['date']), 'time-string' : new BuiltinRule('"\\"" time "\\"" space', ['time']), diff --git a/grammars/README.md b/grammars/README.md index 2b8384d9d..3ffc7cec0 100644 --- a/grammars/README.md +++ b/grammars/README.md @@ -59,9 +59,13 @@ Parentheses `()` can be used to group sequences, which allows for embedding alte ## Repetition and Optional Symbols -- `*` after a symbol or sequence means that it can be repeated zero or more times. -- `+` denotes that the symbol or sequence should appear one or more times. -- `?` makes the preceding symbol or sequence optional. +- `*` after a symbol or sequence means that it can be repeated zero or more times (equivalent to `{0,}`). +- `+` denotes that the symbol or sequence should appear one or more times (equivalent to `{1,}`). +- `?` makes the preceding symbol or sequence optional (equivalent to `{0,1}`). +- `{m}` repeats the precedent symbol or sequence exactly `m` times +- `{m,}` repeats the precedent symbol or sequence at least `m` times +- `{m,n}` repeats the precedent symbol or sequence at between `m` and `n` times (included) +- `{0,n}` repeats the precedent symbol or sequence at most `n` times (included) ## Comments and newlines @@ -98,4 +102,4 @@ Grammars currently have performance gotchas (see https://github.com/ggerganov/ll A common pattern is to allow repetitions of a pattern `x` up to N times. -While semantically correct, the syntax `x? x? x?.... x?` (with N repetitions) will result in extremely slow inference. Instead, you can write `(x (x (x ... (x)?...)?)?)?` (w/ N-deep nesting) +While semantically correct, the syntax `x? x? x?.... x?` (with N repetitions) may result in extremely slow sampling. Instead, you can write `x{0,N}` (or `(x (x (x ... (x)?...)?)?)?` w/ N-deep nesting in earlier llama.cpp versions). diff --git a/tests/test-grammar-integration.cpp b/tests/test-grammar-integration.cpp index 01c5bb27a..9bdab05af 100644 --- a/tests/test-grammar-integration.cpp +++ b/tests/test-grammar-integration.cpp @@ -292,6 +292,82 @@ static void test_quantifiers() { "catyyy", } ); + test_grammar( + "simple exact repetition", + // Grammar + R"""( + root ::= [ab]{4} + )""", + // Passing strings + { + "aaaa", + "bbbb", + "abab", + }, + // Failing strings + { + "a", + "b", + "aaaaa", + } + ); + test_grammar( + "simple min repetition", + // Grammar + R"""( + root ::= [ab]{4,} + )""", + // Passing strings + { + "aaaa", + "aaaaab", + "bbbb", + "ababab", + }, + // Failing strings + { + "", + "aba", + } + ); + test_grammar( + "simple max repetition", + // Grammar + R"""( + root ::= [ab]{0,4} + )""", + // Passing strings + { + "", + "a", + "aa", + "aaa", + "aaab", + }, + // Failing strings + { + "aaaaa", + } + ); + test_grammar( + "min / max repetition", + // Grammar + R"""( + root ::= ("0x" [A-F0-9]{2} " "?){3,5} + )""", + // Passing strings + { + "0xFF 0x12 0xAB", + "0xFF 0x12 0xAB 0x00 0x00", + }, + // Failing strings + { + "", + "0xFF", + "0xFF 0x12", + "0xFF 0x12 0xAB 0x00 0x00 0x00", + } + ); } static void test_failure_missing_root() { diff --git a/tests/test-grammar-parser.cpp b/tests/test-grammar-parser.cpp index 91939e276..5df5abb25 100644 --- a/tests/test-grammar-parser.cpp +++ b/tests/test-grammar-parser.cpp @@ -7,28 +7,79 @@ #include -int main() -{ - grammar_parser::parse_state parsed_grammar; +static const char * type_str(llama_gretype type) { + switch (type) { + case LLAMA_GRETYPE_CHAR: return "LLAMA_GRETYPE_CHAR"; + case LLAMA_GRETYPE_CHAR_NOT: return "LLAMA_GRETYPE_CHAR_NOT"; + case LLAMA_GRETYPE_CHAR_ALT: return "LLAMA_GRETYPE_CHAR_ALT"; + case LLAMA_GRETYPE_CHAR_RNG_UPPER: return "LLAMA_GRETYPE_CHAR_RNG_UPPER"; + case LLAMA_GRETYPE_RULE_REF: return "LLAMA_GRETYPE_RULE_REF"; + case LLAMA_GRETYPE_ALT: return "LLAMA_GRETYPE_ALT"; + case LLAMA_GRETYPE_END: return "LLAMA_GRETYPE_END"; + default: return "?"; + } +} - const char *grammar_bytes = R"""(root ::= (expr "=" term "\n")+ -expr ::= term ([-+*/] term)* -term ::= [0-9]+)"""; +static void verify_parsing(const char *grammar_bytes, const std::vector> expected, const std::vector &expected_rules) { + uint32_t index = 0; + grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_bytes); - parsed_grammar = grammar_parser::parse(grammar_bytes); + std::map symbol_names; + for (auto it = parsed_grammar.symbol_ids.begin(); it != parsed_grammar.symbol_ids.end(); ++it) { + symbol_names[it->second] = it->first; + } - std::vector> expected = { - {"expr", 2}, - {"expr_5", 5}, - {"expr_6", 6}, - {"root", 0}, - {"root_1", 1}, - {"root_4", 4}, - {"term", 3}, - {"term_7", 7}, + auto print_all = [&]() { + fprintf(stderr, " verify_parsing(R\"\"\"(%s)\"\"\", {\n", grammar_bytes); + for (auto it = parsed_grammar.symbol_ids.begin(); it != parsed_grammar.symbol_ids.end(); ++it) { + fprintf(stderr, " {\"%s\", %u},\n", it->first.c_str(), it->second); + } + fprintf(stderr, " }, {\n"); + for (size_t i_rule = 0; i_rule < parsed_grammar.rules.size(); i_rule++) { + fprintf(stderr, " // %s (index %zu)\n", symbol_names[i_rule].c_str(), i_rule); + auto & rule = parsed_grammar.rules[i_rule]; + for (uint32_t i = 0; i < rule.size(); i++) { + std::string rule_str; + fprintf(stderr, " {%s, ", type_str(rule[i].type)); + if (rule[i].type == LLAMA_GRETYPE_CHAR || rule[i].type == LLAMA_GRETYPE_CHAR_ALT || + rule[i].type == LLAMA_GRETYPE_CHAR_NOT || rule[i].type == LLAMA_GRETYPE_CHAR_RNG_UPPER) { + char c = rule[i].value; + if (c == '\n') { + fprintf(stderr, "'\\n'"); + } else if (c == '\t') { + fprintf(stderr, "'\\t'"); + } else if (c == '\r') { + fprintf(stderr, "'\\r'"); + } else if (c == '\0') { + fprintf(stderr, "'\\0'"); + } else { + fprintf(stderr, "'%c'", c); + } + } else if (rule[i].type == LLAMA_GRETYPE_RULE_REF) { + fprintf(stderr, "/* %s */ %u", symbol_names[rule[i].value].c_str(), rule[i].value); + } else { + fprintf(stderr, "%u", rule[i].value); + } + fprintf(stderr, "},\n"); + } + } + fprintf(stderr, " });\n"); }; - uint32_t index = 0; + if (getenv("TEST_GRAMMAR_PARSER_PRINT_ALL")) { + print_all(); + fprintf(stderr, "\n"); + return; + } + + fprintf(stderr, "Testing grammar:%s\n", grammar_bytes); + + if (parsed_grammar.symbol_ids.size() != expected.size()) { + fprintf(stderr, "Code to update expectation (set TEST_GRAMMAR_PARSER_PRINT_ALL=1 to print all):\n"); + print_all(); + assert(parsed_grammar.symbol_ids.size() == expected.size()); + } + for (auto it = parsed_grammar.symbol_ids.begin(); it != parsed_grammar.symbol_ids.end(); ++it) { std::string key = it->first; @@ -38,51 +89,18 @@ term ::= [0-9]+)"""; // pretty print error message before asserting if (expected_pair.first != key || expected_pair.second != value) { + fprintf(stderr, "index: %u\n", index); fprintf(stderr, "expected_pair: %s, %u\n", expected_pair.first.c_str(), expected_pair.second); fprintf(stderr, "actual_pair: %s, %u\n", key.c_str(), value); fprintf(stderr, "expected_pair != actual_pair\n"); + fprintf(stderr, "Code to update expectation (set TEST_GRAMMAR_PARSER_PRINT_ALL=1 to print all):\n"); + print_all(); } assert(expected_pair.first == key && expected_pair.second == value); index++; } - std::vector expected_rules = { - {LLAMA_GRETYPE_RULE_REF, 4}, - {LLAMA_GRETYPE_END, 0}, - {LLAMA_GRETYPE_RULE_REF, 2}, - {LLAMA_GRETYPE_CHAR, 61}, - {LLAMA_GRETYPE_RULE_REF, 3}, - {LLAMA_GRETYPE_CHAR, 10}, - {LLAMA_GRETYPE_END, 0}, - {LLAMA_GRETYPE_RULE_REF, 3}, - {LLAMA_GRETYPE_RULE_REF, 6}, - {LLAMA_GRETYPE_END, 0}, - {LLAMA_GRETYPE_RULE_REF, 7}, - {LLAMA_GRETYPE_END, 0}, - {LLAMA_GRETYPE_RULE_REF, 1}, - {LLAMA_GRETYPE_RULE_REF, 4}, - {LLAMA_GRETYPE_ALT, 0}, - {LLAMA_GRETYPE_RULE_REF, 1}, - {LLAMA_GRETYPE_END, 0}, - {LLAMA_GRETYPE_CHAR, 45}, - {LLAMA_GRETYPE_CHAR_ALT, 43}, - {LLAMA_GRETYPE_CHAR_ALT, 42}, - {LLAMA_GRETYPE_CHAR_ALT, 47}, - {LLAMA_GRETYPE_RULE_REF, 3}, - {LLAMA_GRETYPE_END, 0}, - {LLAMA_GRETYPE_RULE_REF, 5}, - {LLAMA_GRETYPE_RULE_REF, 6}, - {LLAMA_GRETYPE_ALT, 0}, - {LLAMA_GRETYPE_END, 0}, - {LLAMA_GRETYPE_CHAR, 48}, - {LLAMA_GRETYPE_CHAR_RNG_UPPER, 57}, - {LLAMA_GRETYPE_RULE_REF, 7}, - {LLAMA_GRETYPE_ALT, 0}, - {LLAMA_GRETYPE_CHAR, 48}, - {LLAMA_GRETYPE_CHAR_RNG_UPPER, 57}, - {LLAMA_GRETYPE_END, 0}, - }; index = 0; for (auto rule : parsed_grammar.rules) @@ -97,28 +115,306 @@ term ::= [0-9]+)"""; if (expected_element.type != element.type || expected_element.value != element.value) { fprintf(stderr, "index: %u\n", index); - fprintf(stderr, "expected_element: %d, %u\n", expected_element.type, expected_element.value); - fprintf(stderr, "actual_element: %d, %u\n", element.type, element.value); + fprintf(stderr, "expected_element: %s, %u\n", type_str(expected_element.type), expected_element.value); + fprintf(stderr, "actual_element: %s, %u\n", type_str(element.type), element.value); fprintf(stderr, "expected_element != actual_element\n"); + fprintf(stderr, "all elements:\n"); + fprintf(stderr, "Code to update expectation (set TEST_GRAMMAR_PARSER_PRINT_ALL=1 to print all):\n"); + print_all(); } assert(expected_element.type == element.type && expected_element.value == element.value); index++; } } +} - const char *longer_grammar_bytes = R"""( - root ::= (expr "=" ws term "\n")+ - expr ::= term ([-+*/] term)* - term ::= ident | num | "(" ws expr ")" ws - ident ::= [a-z] [a-z0-9_]* ws - num ::= [0-9]+ ws - ws ::= [ \t\n]* - )"""; +static void verify_failure(const char *grammar_bytes) { + fprintf(stderr, "Testing expected failure:%s\n", grammar_bytes); + auto result = grammar_parser::parse(grammar_bytes); + assert(result.rules.empty() && "should have failed"); +} - parsed_grammar = grammar_parser::parse(longer_grammar_bytes); +int main() +{ + verify_failure(R"""( + root ::= "a"{,}" + )"""); - expected = { + verify_failure(R"""( + root ::= "a"{,10}" + )"""); + + verify_parsing(R"""( + root ::= "a" + )""", { + {"root", 0}, + }, { + // root (index 0) + {LLAMA_GRETYPE_CHAR, 'a'}, + {LLAMA_GRETYPE_END, 0}, + }); + + verify_parsing(R"""( + root ::= "a" | [bdx-z] | [^1-3] + )""", { + {"root", 0}, + }, { + // root (index 0) + {LLAMA_GRETYPE_CHAR, 'a'}, + {LLAMA_GRETYPE_ALT, 0}, + {LLAMA_GRETYPE_CHAR, 'b'}, + {LLAMA_GRETYPE_CHAR_ALT, 'd'}, + {LLAMA_GRETYPE_CHAR_ALT, 'x'}, + {LLAMA_GRETYPE_CHAR_RNG_UPPER, 'z'}, + {LLAMA_GRETYPE_ALT, 0}, + {LLAMA_GRETYPE_CHAR_NOT, '1'}, + {LLAMA_GRETYPE_CHAR_RNG_UPPER, '3'}, + {LLAMA_GRETYPE_END, 0}, + }); + + verify_parsing(R"""( + root ::= a+ + a ::= "a" + )""", { + {"a", 1}, + {"root", 0}, + {"root_2", 2}, + }, { + // root (index 0) + {LLAMA_GRETYPE_RULE_REF, /* a */ 1}, + {LLAMA_GRETYPE_RULE_REF, /* root_2 */ 2}, + {LLAMA_GRETYPE_END, 0}, + // a (index 1) + {LLAMA_GRETYPE_CHAR, 'a'}, + {LLAMA_GRETYPE_END, 0}, + // root_2 (index 2) + {LLAMA_GRETYPE_RULE_REF, /* a */ 1}, + {LLAMA_GRETYPE_RULE_REF, /* root_2 */ 2}, + {LLAMA_GRETYPE_ALT, 0}, + {LLAMA_GRETYPE_END, 0}, + }); + + verify_parsing(R"""( + root ::= "a"+ + )""", { + {"root", 0}, + {"root_1", 1}, + }, { + // root (index 0) + {LLAMA_GRETYPE_CHAR, 'a'}, + {LLAMA_GRETYPE_RULE_REF, /* root_1 */ 1}, + {LLAMA_GRETYPE_END, 0}, + // root_1 (index 1) + {LLAMA_GRETYPE_CHAR, 'a'}, + {LLAMA_GRETYPE_RULE_REF, /* root_1 */ 1}, + {LLAMA_GRETYPE_ALT, 0}, + {LLAMA_GRETYPE_END, 0}, + }); + + verify_parsing(R"""( + root ::= a? + a ::= "a" + )""", { + {"a", 1}, + {"root", 0}, + {"root_2", 2}, + }, { + // root (index 0) + {LLAMA_GRETYPE_RULE_REF, /* root_2 */ 2}, + {LLAMA_GRETYPE_END, 0}, + // a (index 1) + {LLAMA_GRETYPE_CHAR, 'a'}, + {LLAMA_GRETYPE_END, 0}, + // root_2 (index 2) + {LLAMA_GRETYPE_RULE_REF, /* a */ 1}, + {LLAMA_GRETYPE_ALT, 0}, + {LLAMA_GRETYPE_END, 0}, + }); + + verify_parsing(R"""( + root ::= "a"? + )""", { + {"root", 0}, + {"root_1", 1}, + }, { + // root (index 0) + {LLAMA_GRETYPE_RULE_REF, /* root_1 */ 1}, + {LLAMA_GRETYPE_END, 0}, + // root_1 (index 1) + {LLAMA_GRETYPE_CHAR, 'a'}, + {LLAMA_GRETYPE_ALT, 0}, + {LLAMA_GRETYPE_END, 0}, + }); + + verify_parsing(R"""( + root ::= a* + a ::= "a" + )""", { + {"a", 1}, + {"root", 0}, + {"root_2", 2}, + }, { + // root (index 0) + {LLAMA_GRETYPE_RULE_REF, /* root_2 */ 2}, + {LLAMA_GRETYPE_END, 0}, + // a (index 1) + {LLAMA_GRETYPE_CHAR, 'a'}, + {LLAMA_GRETYPE_END, 0}, + // root_2 (index 2) + {LLAMA_GRETYPE_RULE_REF, /* a */ 1}, + {LLAMA_GRETYPE_RULE_REF, /* root_2 */ 2}, + {LLAMA_GRETYPE_ALT, 0}, + {LLAMA_GRETYPE_END, 0}, + }); + + verify_parsing(R"""( + root ::= "a"* + )""", { + {"root", 0}, + {"root_1", 1}, + }, { + // root (index 0) + {LLAMA_GRETYPE_RULE_REF, /* root_1 */ 1}, + {LLAMA_GRETYPE_END, 0}, + // root_1 (index 1) + {LLAMA_GRETYPE_CHAR, 'a'}, + {LLAMA_GRETYPE_RULE_REF, /* root_1 */ 1}, + {LLAMA_GRETYPE_ALT, 0}, + {LLAMA_GRETYPE_END, 0}, + }); + + verify_parsing(R"""( + root ::= "a"{2} + )""", { + {"root", 0}, + }, { + // root (index 0) + {LLAMA_GRETYPE_CHAR, 'a'}, + {LLAMA_GRETYPE_CHAR, 'a'}, + {LLAMA_GRETYPE_END, 0}, + }); + + verify_parsing(R"""( + root ::= "a"{2,} + )""", { + {"root", 0}, + {"root_1", 1}, + }, { + // root (index 0) + {LLAMA_GRETYPE_CHAR, 'a'}, + {LLAMA_GRETYPE_CHAR, 'a'}, + {LLAMA_GRETYPE_RULE_REF, /* root_1 */ 1}, + {LLAMA_GRETYPE_END, 0}, + // root_1 (index 1) + {LLAMA_GRETYPE_CHAR, 'a'}, + {LLAMA_GRETYPE_RULE_REF, /* root_1 */ 1}, + {LLAMA_GRETYPE_ALT, 0}, + {LLAMA_GRETYPE_END, 0}, + }); + + verify_parsing(R"""( + root ::= "a"{ 4} + )""", { + {"root", 0}, + }, { + // root (index 0) + {LLAMA_GRETYPE_CHAR, 'a'}, + {LLAMA_GRETYPE_CHAR, 'a'}, + {LLAMA_GRETYPE_CHAR, 'a'}, + {LLAMA_GRETYPE_CHAR, 'a'}, + {LLAMA_GRETYPE_END, 0}, + }); + + verify_parsing(R"""( + root ::= "a"{2,4} + )""", { + {"root", 0}, + {"root_1", 1}, + {"root_2", 2}, + }, { + // root (index 0) + {LLAMA_GRETYPE_CHAR, 'a'}, + {LLAMA_GRETYPE_CHAR, 'a'}, + {LLAMA_GRETYPE_RULE_REF, /* root_2 */ 2}, + {LLAMA_GRETYPE_END, 0}, + // root_1 (index 1) + {LLAMA_GRETYPE_CHAR, 'a'}, + {LLAMA_GRETYPE_ALT, 0}, + {LLAMA_GRETYPE_END, 0}, + // root_2 (index 2) + {LLAMA_GRETYPE_CHAR, 'a'}, + {LLAMA_GRETYPE_RULE_REF, /* root_1 */ 1}, + {LLAMA_GRETYPE_ALT, 0}, + {LLAMA_GRETYPE_END, 0}, + }); + + verify_parsing(R"""( + root ::= (expr "=" term "\n")+ + expr ::= term ([-+*/] term)* + term ::= [0-9]+ + )""", { + {"expr", 2}, + {"expr_5", 5}, + {"expr_6", 6}, + {"root", 0}, + {"root_1", 1}, + {"root_4", 4}, + {"term", 3}, + {"term_7", 7}, + }, { + // root (index 0) + {LLAMA_GRETYPE_RULE_REF, /* root_1 */ 1}, + {LLAMA_GRETYPE_RULE_REF, /* root_4 */ 4}, + {LLAMA_GRETYPE_END, 0}, + // root_1 (index 1) + {LLAMA_GRETYPE_RULE_REF, /* expr */ 2}, + {LLAMA_GRETYPE_CHAR, '='}, + {LLAMA_GRETYPE_RULE_REF, /* term */ 3}, + {LLAMA_GRETYPE_CHAR, '\n'}, + {LLAMA_GRETYPE_END, 0}, + // expr (index 2) + {LLAMA_GRETYPE_RULE_REF, /* term */ 3}, + {LLAMA_GRETYPE_RULE_REF, /* expr_6 */ 6}, + {LLAMA_GRETYPE_END, 0}, + // term (index 3) + {LLAMA_GRETYPE_CHAR, '0'}, + {LLAMA_GRETYPE_CHAR_RNG_UPPER, '9'}, + {LLAMA_GRETYPE_RULE_REF, /* term_7 */ 7}, + {LLAMA_GRETYPE_END, 0}, + // root_4 (index 4) + {LLAMA_GRETYPE_RULE_REF, /* root_1 */ 1}, + {LLAMA_GRETYPE_RULE_REF, /* root_4 */ 4}, + {LLAMA_GRETYPE_ALT, 0}, + {LLAMA_GRETYPE_END, 0}, + // expr_5 (index 5) + {LLAMA_GRETYPE_CHAR, '-'}, + {LLAMA_GRETYPE_CHAR_ALT, '+'}, + {LLAMA_GRETYPE_CHAR_ALT, '*'}, + {LLAMA_GRETYPE_CHAR_ALT, '/'}, + {LLAMA_GRETYPE_RULE_REF, /* term */ 3}, + {LLAMA_GRETYPE_END, 0}, + // expr_6 (index 6) + {LLAMA_GRETYPE_RULE_REF, /* expr_5 */ 5}, + {LLAMA_GRETYPE_RULE_REF, /* expr_6 */ 6}, + {LLAMA_GRETYPE_ALT, 0}, + {LLAMA_GRETYPE_END, 0}, + // term_7 (index 7) + {LLAMA_GRETYPE_CHAR, '0'}, + {LLAMA_GRETYPE_CHAR_RNG_UPPER, '9'}, + {LLAMA_GRETYPE_RULE_REF, /* term_7 */ 7}, + {LLAMA_GRETYPE_ALT, 0}, + {LLAMA_GRETYPE_END, 0}, + }); + + verify_parsing(R"""( + root ::= (expr "=" ws term "\n")+ + expr ::= term ([-+*/] term)* + term ::= ident | num | "(" ws expr ")" ws + ident ::= [a-z] [a-z0-9_]* ws + num ::= [0-9]+ ws + ws ::= [ \t\n]* + )""", { {"expr", 2}, {"expr_6", 6}, {"expr_7", 7}, @@ -132,119 +428,88 @@ term ::= [0-9]+)"""; {"term", 4}, {"ws", 3}, {"ws_12", 12}, - }; - - index = 0; - for (auto it = parsed_grammar.symbol_ids.begin(); it != parsed_grammar.symbol_ids.end(); ++it) - { - std::string key = it->first; - uint32_t value = it->second; - std::pair expected_pair = expected[index]; - - // pretty print error message before asserting - if (expected_pair.first != key || expected_pair.second != value) - { - fprintf(stderr, "expected_pair: %s, %u\n", expected_pair.first.c_str(), expected_pair.second); - fprintf(stderr, "actual_pair: %s, %u\n", key.c_str(), value); - fprintf(stderr, "expected_pair != actual_pair\n"); - } - - assert(expected_pair.first == key && expected_pair.second == value); - - index++; - } - expected_rules = { - {LLAMA_GRETYPE_RULE_REF, 5}, + }, { + // root (index 0) + {LLAMA_GRETYPE_RULE_REF, /* root_1 */ 1}, + {LLAMA_GRETYPE_RULE_REF, /* root_5 */ 5}, {LLAMA_GRETYPE_END, 0}, - {LLAMA_GRETYPE_RULE_REF, 2}, - {LLAMA_GRETYPE_CHAR, 61}, - {LLAMA_GRETYPE_RULE_REF, 3}, - {LLAMA_GRETYPE_RULE_REF, 4}, - {LLAMA_GRETYPE_CHAR, 10}, + // root_1 (index 1) + {LLAMA_GRETYPE_RULE_REF, /* expr */ 2}, + {LLAMA_GRETYPE_CHAR, '='}, + {LLAMA_GRETYPE_RULE_REF, /* ws */ 3}, + {LLAMA_GRETYPE_RULE_REF, /* term */ 4}, + {LLAMA_GRETYPE_CHAR, '\n'}, {LLAMA_GRETYPE_END, 0}, - {LLAMA_GRETYPE_RULE_REF, 4}, - {LLAMA_GRETYPE_RULE_REF, 7}, + // expr (index 2) + {LLAMA_GRETYPE_RULE_REF, /* term */ 4}, + {LLAMA_GRETYPE_RULE_REF, /* expr_7 */ 7}, {LLAMA_GRETYPE_END, 0}, - {LLAMA_GRETYPE_RULE_REF, 12}, + // ws (index 3) + {LLAMA_GRETYPE_RULE_REF, /* ws_12 */ 12}, {LLAMA_GRETYPE_END, 0}, - {LLAMA_GRETYPE_RULE_REF, 8}, + // term (index 4) + {LLAMA_GRETYPE_RULE_REF, /* ident */ 8}, {LLAMA_GRETYPE_ALT, 0}, - {LLAMA_GRETYPE_RULE_REF, 9}, + {LLAMA_GRETYPE_RULE_REF, /* num */ 9}, {LLAMA_GRETYPE_ALT, 0}, - {LLAMA_GRETYPE_CHAR, 40}, - {LLAMA_GRETYPE_RULE_REF, 3}, - {LLAMA_GRETYPE_RULE_REF, 2}, - {LLAMA_GRETYPE_CHAR, 41}, - {LLAMA_GRETYPE_RULE_REF, 3}, + {LLAMA_GRETYPE_CHAR, '('}, + {LLAMA_GRETYPE_RULE_REF, /* ws */ 3}, + {LLAMA_GRETYPE_RULE_REF, /* expr */ 2}, + {LLAMA_GRETYPE_CHAR, ')'}, + {LLAMA_GRETYPE_RULE_REF, /* ws */ 3}, {LLAMA_GRETYPE_END, 0}, - {LLAMA_GRETYPE_RULE_REF, 1}, - {LLAMA_GRETYPE_RULE_REF, 5}, - {LLAMA_GRETYPE_ALT, 0}, - {LLAMA_GRETYPE_RULE_REF, 1}, - {LLAMA_GRETYPE_END, 0}, - {LLAMA_GRETYPE_CHAR, 45}, - {LLAMA_GRETYPE_CHAR_ALT, 43}, - {LLAMA_GRETYPE_CHAR_ALT, 42}, - {LLAMA_GRETYPE_CHAR_ALT, 47}, - {LLAMA_GRETYPE_RULE_REF, 4}, - {LLAMA_GRETYPE_END, 0}, - {LLAMA_GRETYPE_RULE_REF, 6}, - {LLAMA_GRETYPE_RULE_REF, 7}, + // root_5 (index 5) + {LLAMA_GRETYPE_RULE_REF, /* root_1 */ 1}, + {LLAMA_GRETYPE_RULE_REF, /* root_5 */ 5}, {LLAMA_GRETYPE_ALT, 0}, {LLAMA_GRETYPE_END, 0}, - {LLAMA_GRETYPE_CHAR, 97}, - {LLAMA_GRETYPE_CHAR_RNG_UPPER, 122}, - {LLAMA_GRETYPE_RULE_REF, 10}, - {LLAMA_GRETYPE_RULE_REF, 3}, + // expr_6 (index 6) + {LLAMA_GRETYPE_CHAR, '-'}, + {LLAMA_GRETYPE_CHAR_ALT, '+'}, + {LLAMA_GRETYPE_CHAR_ALT, '*'}, + {LLAMA_GRETYPE_CHAR_ALT, '/'}, + {LLAMA_GRETYPE_RULE_REF, /* term */ 4}, {LLAMA_GRETYPE_END, 0}, - {LLAMA_GRETYPE_RULE_REF, 11}, - {LLAMA_GRETYPE_RULE_REF, 3}, - {LLAMA_GRETYPE_END, 0}, - {LLAMA_GRETYPE_CHAR, 97}, - {LLAMA_GRETYPE_CHAR_RNG_UPPER, 122}, - {LLAMA_GRETYPE_CHAR_ALT, 48}, - {LLAMA_GRETYPE_CHAR_RNG_UPPER, 57}, - {LLAMA_GRETYPE_CHAR_ALT, 95}, - {LLAMA_GRETYPE_RULE_REF, 10}, + // expr_7 (index 7) + {LLAMA_GRETYPE_RULE_REF, /* expr_6 */ 6}, + {LLAMA_GRETYPE_RULE_REF, /* expr_7 */ 7}, {LLAMA_GRETYPE_ALT, 0}, {LLAMA_GRETYPE_END, 0}, - {LLAMA_GRETYPE_CHAR, 48}, - {LLAMA_GRETYPE_CHAR_RNG_UPPER, 57}, - {LLAMA_GRETYPE_RULE_REF, 11}, - {LLAMA_GRETYPE_ALT, 0}, - {LLAMA_GRETYPE_CHAR, 48}, - {LLAMA_GRETYPE_CHAR_RNG_UPPER, 57}, + // ident (index 8) + {LLAMA_GRETYPE_CHAR, 'a'}, + {LLAMA_GRETYPE_CHAR_RNG_UPPER, 'z'}, + {LLAMA_GRETYPE_RULE_REF, /* ident_10 */ 10}, + {LLAMA_GRETYPE_RULE_REF, /* ws */ 3}, {LLAMA_GRETYPE_END, 0}, - {LLAMA_GRETYPE_CHAR, 32}, - {LLAMA_GRETYPE_CHAR_ALT, 9}, - {LLAMA_GRETYPE_CHAR_ALT, 10}, - {LLAMA_GRETYPE_RULE_REF, 12}, + // num (index 9) + {LLAMA_GRETYPE_CHAR, '0'}, + {LLAMA_GRETYPE_CHAR_RNG_UPPER, '9'}, + {LLAMA_GRETYPE_RULE_REF, /* num_11 */ 11}, + {LLAMA_GRETYPE_RULE_REF, /* ws */ 3}, + {LLAMA_GRETYPE_END, 0}, + // ident_10 (index 10) + {LLAMA_GRETYPE_CHAR, 'a'}, + {LLAMA_GRETYPE_CHAR_RNG_UPPER, 'z'}, + {LLAMA_GRETYPE_CHAR_ALT, '0'}, + {LLAMA_GRETYPE_CHAR_RNG_UPPER, '9'}, + {LLAMA_GRETYPE_CHAR_ALT, '_'}, + {LLAMA_GRETYPE_RULE_REF, /* ident_10 */ 10}, {LLAMA_GRETYPE_ALT, 0}, {LLAMA_GRETYPE_END, 0}, - }; - - index = 0; - for (auto rule : parsed_grammar.rules) - { - // compare rule to expected rule - for (uint32_t i = 0; i < rule.size(); i++) - { - llama_grammar_element element = rule[i]; - llama_grammar_element expected_element = expected_rules[index]; - - // pretty print error message before asserting - if (expected_element.type != element.type || expected_element.value != element.value) - { - fprintf(stderr, "index: %u\n", index); - fprintf(stderr, "expected_element: %d, %u\n", expected_element.type, expected_element.value); - fprintf(stderr, "actual_element: %d, %u\n", element.type, element.value); - fprintf(stderr, "expected_element != actual_element\n"); - } - - assert(expected_element.type == element.type && expected_element.value == element.value); - index++; - } - } + // num_11 (index 11) + {LLAMA_GRETYPE_CHAR, '0'}, + {LLAMA_GRETYPE_CHAR_RNG_UPPER, '9'}, + {LLAMA_GRETYPE_RULE_REF, /* num_11 */ 11}, + {LLAMA_GRETYPE_ALT, 0}, + {LLAMA_GRETYPE_END, 0}, + // ws_12 (index 12) + {LLAMA_GRETYPE_CHAR, ' '}, + {LLAMA_GRETYPE_CHAR_ALT, '\t'}, + {LLAMA_GRETYPE_CHAR_ALT, '\n'}, + {LLAMA_GRETYPE_RULE_REF, /* ws_12 */ 12}, + {LLAMA_GRETYPE_ALT, 0}, + {LLAMA_GRETYPE_END, 0}, + }); return 0; } diff --git a/tests/test-json-schema-to-grammar.cpp b/tests/test-json-schema-to-grammar.cpp index c5361b5b8..052c08073 100755 --- a/tests/test-json-schema-to-grammar.cpp +++ b/tests/test-json-schema-to-grammar.cpp @@ -105,9 +105,9 @@ static void test_all(const std::string & lang, std::function Date: Thu, 6 Jun 2024 09:17:54 -0300 Subject: [PATCH 05/21] README minor fixes (#7798) [no ci] derievatives --> derivatives --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 9d2a59d89..09e8cad31 100644 --- a/README.md +++ b/README.md @@ -598,7 +598,7 @@ Building the program with BLAS support may lead to some performance improvements To obtain the official LLaMA 2 weights please see the Obtaining and using the Facebook LLaMA 2 model section. There is also a large selection of pre-quantized `gguf` models available on Hugging Face. -Note: `convert.py` has been moved to `examples/convert-legacy-llama.py` and shouldn't be used for anything other than `Llama/Llama2/Mistral` models and their derievatives. +Note: `convert.py` has been moved to `examples/convert-legacy-llama.py` and shouldn't be used for anything other than `Llama/Llama2/Mistral` models and their derivatives. It does not support LLaMA 3, you can use `convert-hf-to-gguf.py` with LLaMA 3 downloaded from Hugging Face. ```bash From ad675e1c67a05b16e4e12abe30dbecfc808e7b7e Mon Sep 17 00:00:00 2001 From: Clint Herron Date: Thu, 6 Jun 2024 06:08:52 -0700 Subject: [PATCH 06/21] Added support for . (any character) token in grammar engine. (#6467) * Added support for . (any characer) token in grammar engine. * Add integration tests for any-character symbol. --- common/grammar-parser.cpp | 11 +++++++++++ llama.cpp | 12 ++++++++++-- llama.h | 3 +++ tests/test-grammar-integration.cpp | 28 ++++++++++++++++++++++++++++ 4 files changed, 52 insertions(+), 2 deletions(-) diff --git a/common/grammar-parser.cpp b/common/grammar-parser.cpp index 79d2b0354..a518b766d 100644 --- a/common/grammar-parser.cpp +++ b/common/grammar-parser.cpp @@ -266,6 +266,10 @@ namespace grammar_parser { throw std::runtime_error(std::string("expecting ')' at ") + pos); } pos = parse_space(pos + 1, is_nested); + } else if (*pos == '.') { // any char + last_sym_start = out_elements.size(); + out_elements.push_back({LLAMA_GRETYPE_CHAR_ANY, 0}); + pos = parse_space(pos + 1, is_nested); } else if (*pos == '*') { pos = parse_space(pos + 1, is_nested); handle_repetitions(0, -1); @@ -401,6 +405,7 @@ namespace grammar_parser { case LLAMA_GRETYPE_CHAR_NOT: return true; case LLAMA_GRETYPE_CHAR_ALT: return true; case LLAMA_GRETYPE_CHAR_RNG_UPPER: return true; + case LLAMA_GRETYPE_CHAR_ANY: return true; default: return false; } } @@ -415,6 +420,7 @@ namespace grammar_parser { case LLAMA_GRETYPE_CHAR_NOT: fprintf(file, "CHAR_NOT"); break; case LLAMA_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break; case LLAMA_GRETYPE_CHAR_ALT: fprintf(file, "CHAR_ALT"); break; + case LLAMA_GRETYPE_CHAR_ANY: fprintf(file, "CHAR_ANY"); break; } switch (elem.type) { case LLAMA_GRETYPE_END: @@ -426,6 +432,7 @@ namespace grammar_parser { case LLAMA_GRETYPE_CHAR_NOT: case LLAMA_GRETYPE_CHAR_RNG_UPPER: case LLAMA_GRETYPE_CHAR_ALT: + case LLAMA_GRETYPE_CHAR_ANY: fprintf(file, "(\""); print_grammar_char(file, elem.value); fprintf(file, "\") "); @@ -483,11 +490,15 @@ namespace grammar_parser { } print_grammar_char(file, elem.value); break; + case LLAMA_GRETYPE_CHAR_ANY: + fprintf(file, "."); + break; } if (is_char_element(elem)) { switch (rule[i + 1].type) { case LLAMA_GRETYPE_CHAR_ALT: case LLAMA_GRETYPE_CHAR_RNG_UPPER: + case LLAMA_GRETYPE_CHAR_ANY: break; default: fprintf(file, "] "); diff --git a/llama.cpp b/llama.cpp index cefb4d1d5..32264a008 100644 --- a/llama.cpp +++ b/llama.cpp @@ -13640,7 +13640,7 @@ static std::pair llama_grammar_match_char( const uint32_t chr) { bool found = false; - bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR; + bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR || pos->type == LLAMA_GRETYPE_CHAR_ANY; GGML_ASSERT(is_positive_char || pos->type == LLAMA_GRETYPE_CHAR_NOT); // NOLINT @@ -13649,6 +13649,10 @@ static std::pair llama_grammar_match_char( // inclusive range, e.g. [a-z] found = found || (pos->value <= chr && chr <= pos[1].value); pos += 2; + } else if (pos->type == LLAMA_GRETYPE_CHAR_ANY) { + // Any character matches "." + found = true; + pos += 1; } else { // exact char match, e.g. [a] or "a" found = found || pos->value == chr; @@ -13666,7 +13670,7 @@ static bool llama_grammar_match_partial_char( const llama_grammar_element * pos, const llama_partial_utf8 partial_utf8) { - bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR; + bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR || pos->type == LLAMA_GRETYPE_CHAR_ANY; GGML_ASSERT(is_positive_char || pos->type == LLAMA_GRETYPE_CHAR_NOT); uint32_t partial_value = partial_utf8.value; @@ -13696,6 +13700,9 @@ static bool llama_grammar_match_partial_char( return is_positive_char; } pos += 2; + } else if (pos->type == LLAMA_GRETYPE_CHAR_ANY) { + // Any character matches "." + return true; } else { // exact char match, e.g. [a] or "a" if (low <= pos->value && pos->value <= high) { @@ -13756,6 +13763,7 @@ static void llama_grammar_advance_stack( } case LLAMA_GRETYPE_CHAR: case LLAMA_GRETYPE_CHAR_NOT: + case LLAMA_GRETYPE_CHAR_ANY: if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) { // only add the stack if it's not a duplicate of one we already have new_stacks.emplace_back(stack); diff --git a/llama.h b/llama.h index 9dcd67bef..62908261f 100644 --- a/llama.h +++ b/llama.h @@ -365,6 +365,9 @@ extern "C" { // modifies a preceding LLAMA_GRETYPE_CHAR or // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA]) LLAMA_GRETYPE_CHAR_ALT = 6, + + // any character (.) + LLAMA_GRETYPE_CHAR_ANY = 7, }; typedef struct llama_grammar_element { diff --git a/tests/test-grammar-integration.cpp b/tests/test-grammar-integration.cpp index 9bdab05af..8787fb1ec 100644 --- a/tests/test-grammar-integration.cpp +++ b/tests/test-grammar-integration.cpp @@ -205,6 +205,33 @@ static void test_complex_grammar() { ); } +static void test_special_chars() { + // A collection of tests to exercise special characters such as "." + test_grammar( + "special characters", + // Grammar + R"""( + root ::= ... "abc" ... + )""", + // Passing strings + { + "abcabcabc", + "aaaabcccc", + // NOTE: Also ensures that multi-byte characters still count as a single character + "🔵🟠✅abc❌🟠🔵" + }, + // Failing strings + { + "aaabcccc", + "aaaaabcccc", + "aaaabccc", + "aaaabccccc", + "🔵🟠✅❌abc❌✅🟠🔵" + "🔵🟠abc🟠🔵" + } + ); +} + static void test_quantifiers() { // A collection of tests to exercise * + and ? quantifiers @@ -445,6 +472,7 @@ int main() { fprintf(stdout, "Running grammar integration tests...\n"); test_simple_grammar(); test_complex_grammar(); + test_special_chars(); test_quantifiers(); test_failure_missing_root(); test_failure_missing_reference(); From f83351f9a62a6262f1fc3d08f320033089cddfb5 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 6 Jun 2024 16:30:58 +0300 Subject: [PATCH 07/21] imatrix : migrate to gpt_params (#7771) * imatrix : migrate to gpt_params ggml-ci * imatrix : add --save-frequency cli arg * common : fix --no-ppl --- common/common.cpp | 75 ++++++++++- common/common.h | 99 +++++++------- examples/imatrix/README.md | 11 +- examples/imatrix/imatrix.cpp | 241 +++++++++++------------------------ examples/server/server.cpp | 2 +- 5 files changed, 213 insertions(+), 215 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index c8df9a4ce..601bd2164 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -273,6 +273,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { } } catch (const std::invalid_argument & ex) { fprintf(stderr, "%s\n", ex.what()); + params = params_org; return false; } @@ -408,6 +409,20 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa } return true; } + if (arg == "--in-file") { + if (++i >= argc) { + invalid_param = true; + return true; + } + std::ifstream file(argv[i]); + if (!file) { + fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); + invalid_param = true; + return true; + } + params.in_files.push_back(argv[i]); + return true; + } if (arg == "-n" || arg == "--predict" || arg == "--n-predict") { if (++i >= argc) { invalid_param = true; @@ -1081,7 +1096,15 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "-v" || arg == "--verbose") { - params.verbose = true; + params.verbosity = 1; + return true; + } + if (arg == "--verbosity") { + if (++i >= argc) { + invalid_param = true; + return true; + } + params.verbosity = std::stoi(argv[i]); return true; } if (arg == "--verbose-prompt") { @@ -1537,6 +1560,46 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.i_pos = std::stoi(argv[i]); return true; } + if (arg == "-o" || arg == "--output" || arg == "--output-file") { + if (++i >= argc) { + invalid_param = true; + return true; + } + params.out_file = argv[i]; + return true; + } + if (arg == "-ofreq" || arg == "--output-frequency") { + if (++i >= argc) { + invalid_param = true; + return true; + } + params.n_out_freq = std::stoi(argv[i]); + return true; + } + if (arg == "--save-frequency") { + if (++i >= argc) { + invalid_param = true; + return true; + } + params.n_save_freq = std::stoi(argv[i]); + return true; + } + if (arg == "--process-output") { + params.process_output = true; + return true; + } + if (arg == "--no-ppl") { + params.compute_ppl = false; + return true; + } + if (arg == "--chunk" || arg == "--from-chunk") { + if (++i >= argc) { + invalid_param = true; + return true; + } + params.i_chunk = std::stoi(argv[i]); + return true; + } #ifndef LOG_DISABLE_LOGS // Parse args for logging parameters if (log_param_single_parse(argv[i])) { @@ -1612,6 +1675,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "*", "-h, --help, --usage", "print usage and exit" }); options.push_back({ "*", " --version", "show version and build info" }); options.push_back({ "*", "-v, --verbose", "print verbose information" }); + options.push_back({ "*", " --verbosity N", "set specific verbosity level (default: %d)", params.verbosity }); options.push_back({ "*", " --verbose-prompt", "print a verbose prompt before generation (default: %s)", params.verbose_prompt ? "true" : "false" }); options.push_back({ "*", " --no-display-prompt", "don't print prompt at generation (default: %s)", !params.display_prompt ? "true" : "false" }); options.push_back({ "*", "-co, --color", "colorise output to distinguish prompt and user input from generations (default: %s)", params.use_color ? "true" : "false" }); @@ -1637,6 +1701,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "*", "-fa, --flash-attn", "enable Flash Attention (default: %s)", params.flash_attn ? "enabled" : "disabled" }); options.push_back({ "*", "-p, --prompt PROMPT", "prompt to start generation with (default: '%s')", params.prompt.c_str() }); options.push_back({ "*", "-f, --file FNAME", "a file containing the prompt (default: none)" }); + options.push_back({ "*", " --in-file FNAME", "an input file (repeat to specify multiple files)" }); options.push_back({ "*", "-bf, --binary-file FNAME", "binary file containing the prompt (default: none)" }); options.push_back({ "*", "-e, --escape", "process escapes sequences (\\n, \\r, \\t, \\', \\\", \\\\) (default: %s)", params.escape ? "true" : "false" }); options.push_back({ "*", " --no-escape", "do not process escape sequences" }); @@ -1804,6 +1869,14 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "passkey", " --junk N", "number of times to repeat the junk text (default: %d)", params.n_junk }); options.push_back({ "passkey", " --pos N", "position of the passkey in the junk text (default: %d)", params.i_pos }); + options.push_back({ "imatrix" }); + options.push_back({ "imatrix", "-o, --output FNAME", "output file (default: '%s')", params.out_file.c_str() }); + options.push_back({ "imatrix", " --output-frequency N", "output the imatrix every N iterations (default: %d)", params.n_out_freq }); + options.push_back({ "imatrix", " --save-frequency N", "save an imatrix copy every N iterations (default: %d)", params.n_save_freq }); + options.push_back({ "imatrix", " --process-output", "collect data for the output tensor (default: %s)", params.process_output ? "true" : "false" }); + options.push_back({ "imatrix", " --no-ppl", "do not compute perplexity (default: %s)", params.compute_ppl ? "true" : "false" }); + options.push_back({ "imatrix", " --chunk N", "start processing the input from chunk N (default: %d)", params.i_chunk }); + options.push_back({ "bench" }); options.push_back({ "bench", "-pps", "is the prompt shared across parallel sequences (default: %s)", params.is_pp_shared ? "true" : "false" }); options.push_back({ "bench", "-npp n0,n1,...", "number of prompt tokens" }); diff --git a/common/common.h b/common/common.h index e0a08a61b..de6238e27 100644 --- a/common/common.h +++ b/common/common.h @@ -56,43 +56,42 @@ struct gpt_params { uint32_t seed = LLAMA_DEFAULT_SEED; // RNG seed int32_t n_threads = cpu_get_num_math(); - int32_t n_threads_draft = -1; - int32_t n_threads_batch = -1; // number of threads to use for batch processing (-1 = use n_threads) - int32_t n_threads_batch_draft = -1; - int32_t n_predict = -1; // new tokens to predict - int32_t n_ctx = 0; // context size - int32_t n_batch = 2048; // logical batch size for prompt processing (must be >=32 to use BLAS) - int32_t n_ubatch = 512; // physical batch size for prompt processing (must be >=32 to use BLAS) - int32_t n_keep = 0; // number of tokens to keep from initial prompt - int32_t n_draft = 5; // number of tokens to draft during speculative decoding - int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited) - int32_t n_parallel = 1; // number of parallel sequences to decode - int32_t n_sequences = 1; // number of sequences to decode - float p_split = 0.1f; // speculative decoding split probability - int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default) - int32_t n_gpu_layers_draft = -1; // number of layers to store in VRAM for the draft model (-1 - use default) - llama_split_mode split_mode = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs - int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors - float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs - int32_t n_beams = 0; // if non-zero then use beam search of given width. - int32_t grp_attn_n = 1; // group-attention factor - int32_t grp_attn_w = 512; // group-attention width - int32_t n_print = -1; // print token count every n tokens (-1 = disabled) - float rope_freq_base = 0.0f; // RoPE base frequency - float rope_freq_scale = 0.0f; // RoPE frequency scaling factor + int32_t n_threads_draft = -1; + int32_t n_threads_batch = -1; // number of threads to use for batch processing (-1 = use n_threads) + int32_t n_threads_batch_draft = -1; + int32_t n_predict = -1; // new tokens to predict + int32_t n_ctx = 0; // context size + int32_t n_batch = 2048; // logical batch size for prompt processing (must be >=32 to use BLAS) + int32_t n_ubatch = 512; // physical batch size for prompt processing (must be >=32 to use BLAS) + int32_t n_keep = 0; // number of tokens to keep from initial prompt + int32_t n_draft = 5; // number of tokens to draft during speculative decoding + int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited) + int32_t n_parallel = 1; // number of parallel sequences to decode + int32_t n_sequences = 1; // number of sequences to decode + float p_split = 0.1f; // speculative decoding split probability + int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default) + int32_t n_gpu_layers_draft = -1; // number of layers to store in VRAM for the draft model (-1 - use default) + int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors + float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs + int32_t n_beams = 0; // if non-zero then use beam search of given width. + int32_t grp_attn_n = 1; // group-attention factor + int32_t grp_attn_w = 512; // group-attention width + int32_t n_print = -1; // print token count every n tokens (-1 = disabled) + float rope_freq_base = 0.0f; // RoPE base frequency + float rope_freq_scale = 0.0f; // RoPE frequency scaling factor float yarn_ext_factor = -1.0f; // YaRN extrapolation mix factor - float yarn_attn_factor = 1.0f; // YaRN magnitude scaling factor + float yarn_attn_factor = 1.0f; // YaRN magnitude scaling factor float yarn_beta_fast = 32.0f; // YaRN low correction dim - float yarn_beta_slow = 1.0f; // YaRN high correction dim - int32_t yarn_orig_ctx = 0; // YaRN original context length + float yarn_beta_slow = 1.0f; // YaRN high correction dim + int32_t yarn_orig_ctx = 0; // YaRN original context length float defrag_thold = -1.0f; // KV cache defragmentation threshold - std::string rpc_servers = ""; // comma separated list of RPC servers ggml_backend_sched_eval_callback cb_eval = nullptr; void * cb_eval_user_data = nullptr; ggml_numa_strategy numa = GGML_NUMA_STRATEGY_DISABLED; + enum llama_split_mode split_mode = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs enum llama_rope_scaling_type rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED; enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings @@ -114,7 +113,9 @@ struct gpt_params { std::string lookup_cache_static = ""; // path of static ngram cache file for lookup decoding std::string lookup_cache_dynamic = ""; // path of dynamic ngram cache file for lookup decoding std::string logits_file = ""; // file for saving *all* logits + std::string rpc_servers = ""; // comma separated list of RPC servers + std::vector in_files; // all input files std::vector antiprompt; // strings upon which more user input is prompted (a.k.a. reverse prompts) std::vector kv_overrides; @@ -124,23 +125,24 @@ struct gpt_params { std::vector control_vectors; // control vector with user defined scale + int32_t verbosity = 0; int32_t control_vector_layer_start = -1; // layer range for control vector int32_t control_vector_layer_end = -1; // layer range for control vector - int32_t ppl_stride = 0; // stride for perplexity calculations. If left at 0, the pre-existing approach will be used. - int32_t ppl_output_type = 0; // = 0 -> ppl output is as usual, = 1 -> ppl output is num_tokens, ppl, one per line - // (which is more convenient to use for plotting) - // - bool hellaswag = false; // compute HellaSwag score over random tasks from datafile supplied in prompt - size_t hellaswag_tasks = 400; // number of tasks to use when computing the HellaSwag score + int32_t ppl_stride = 0; // stride for perplexity calculations. If left at 0, the pre-existing approach will be used. + int32_t ppl_output_type = 0; // = 0 -> ppl output is as usual, = 1 -> ppl output is num_tokens, ppl, one per line + // (which is more convenient to use for plotting) + // + bool hellaswag = false; // compute HellaSwag score over random tasks from datafile supplied in prompt + size_t hellaswag_tasks = 400; // number of tasks to use when computing the HellaSwag score - bool winogrande = false; // compute Winogrande score over random tasks from datafile supplied in prompt - size_t winogrande_tasks= 0; // number of tasks to use when computing the Winogrande score. If 0, all tasks will be computed + bool winogrande = false; // compute Winogrande score over random tasks from datafile supplied in prompt + size_t winogrande_tasks = 0; // number of tasks to use when computing the Winogrande score. If 0, all tasks will be computed - bool multiple_choice = false; // compute TruthfulQA score over random tasks from datafile supplied in prompt - size_t multiple_choice_tasks = 0; // number of tasks to use when computing the TruthfulQA score. If 0, all tasks will be computed + bool multiple_choice = false; // compute TruthfulQA score over random tasks from datafile supplied in prompt + size_t multiple_choice_tasks = 0; // number of tasks to use when computing the TruthfulQA score. If 0, all tasks will be computed - bool kl_divergence = false; // compute KL divergence + bool kl_divergence = false; // compute KL divergence bool usage = false; // print usage bool use_color = false; // use color to distinguish generations and inputs @@ -163,7 +165,6 @@ struct gpt_params { bool logits_all = false; // return logits for all tokens in the batch bool use_mmap = true; // use mmap for faster loads bool use_mlock = false; // use mlock to keep model in memory - bool verbose = false; bool verbose_prompt = false; // print prompt tokens before generation bool display_prompt = true; // print prompt before generation bool infill = false; // use infill mode @@ -180,10 +181,10 @@ struct gpt_params { std::vector image; // path to image file(s) // server params - int32_t port = 8080; - int32_t timeout_read = 600; - int32_t timeout_write = timeout_read; - int32_t n_threads_http = -1; + int32_t port = 8080; // server listens on this network port + int32_t timeout_read = 600; // http read timeout in seconds + int32_t timeout_write = timeout_read; // http write timeout in seconds + int32_t n_threads_http = -1; // number of threads to use for http server (-1 = use n_threads) std::string hostname = "127.0.0.1"; std::string public_path = ""; @@ -219,6 +220,16 @@ struct gpt_params { // passkey params int32_t n_junk = 250; // number of times to repeat the junk text int32_t i_pos = -1; // position of the passkey in the junk text + + // imatrix params + std::string out_file = "imatrix.dat"; // save the resulting imatrix to this file + + int32_t n_out_freq = 10; // output the imatrix every n_out_freq iterations + int32_t n_save_freq = 0; // save the imatrix every n_save_freq iterations + int32_t i_chunk = 0; // start processing from this chunk + + bool process_output = false; // collect data for the output tensor + bool compute_ppl = true; // whether to compute perplexity }; void gpt_params_handle_model_default(gpt_params & params); diff --git a/examples/imatrix/README.md b/examples/imatrix/README.md index 458c01b87..866ca9f56 100644 --- a/examples/imatrix/README.md +++ b/examples/imatrix/README.md @@ -6,16 +6,19 @@ More information is available here: https://github.com/ggerganov/llama.cpp/pull/ ## Usage ``` -./imatrix -m -f [-o ] [--verbosity ] - [-ofreq num_chunks] [-ow <0 or 1>] [other common params] +./imatrix \ + -m model.gguf -f some-text.txt [-o imatrix.dat] [--process-output] [--verbosity 1] \ + [--no-ppl] [--chunk 123] [--output-frequency 10] [--save-frequency 0] \ + [--in-file imatrix-prev-0.dat --in-file imatrix-prev-1.dat ...] ``` Here `-m` with a model name and `-f` with a file containing training data (such as e.g. `wiki.train.raw`) are mandatory. The parameters in square brackets are optional and have the following meaning: * `-o` (or `--output-file`) specifies the name of the file where the computed data will be stored. If missing `imatrix.dat` is used. * `--verbosity` specifies the verbosity level. If set to `0`, no output other than the perplexity of the processed chunks will be generated. If set to `1`, each time the results are saved a message is written to `stderr`. If `>=2`, a message is output each time data is collected for any tensor. Default verbosity level is `1`. -* `-ofreq` (or `--output-frequency`) specifies how often the so far computed result is saved to disk. Default is 10 (i.e., every 10 chunks) -* `-ow` (or `--output-weight`) specifies if data will be collected for the `output.weight` tensor. My experience is that it is better to not utilize the importance matrix when quantizing `output.weight`, so this is set to `false` by default. +* `--output-frequency` specifies how often the so far computed result is saved to disk. Default is 10 (i.e., every 10 chunks) +* `--save-frequency` specifies how often to save a copy of the imatrix in a separate file. Default is 0 (i.e., never) +* `--process-output` specifies if data will be collected for the `output.weight` tensor. My experience is that it is better to not utilize the importance matrix when quantizing `output.weight`, so this is set to `false` by default. For faster computation, make sure to use GPU offloading via the `-ngl` argument diff --git a/examples/imatrix/imatrix.cpp b/examples/imatrix/imatrix.cpp index e050c09d2..38420041c 100644 --- a/examples/imatrix/imatrix.cpp +++ b/examples/imatrix/imatrix.cpp @@ -17,39 +17,37 @@ #pragma warning(disable: 4244 4267) // possible loss of data #endif +static void print_usage(int argc, char ** argv, const gpt_params & params) { + gpt_params_print_usage(argc, argv, params); + + LOG_TEE("\nexample usage:\n"); + LOG_TEE("\n %s \\\n" + " -m model.gguf -f some-text.txt [-o imatrix.dat] [--process-output] [--verbosity 1] \\\n" + " [--no-ppl] [--chunk 123] [--output-frequency 10] [--save-frequency 0] \\\n" + " [--in-file imatrix-prev-0.dat --in-file imatrix-prev-1.dat ...]\n" , argv[0]); + LOG_TEE("\n"); +} + struct Stats { std::vector values; std::vector counts; int ncall = 0; }; -struct StatParams { - std::string dataset; - std::string ofile = "imatrix.dat"; - int n_output_frequency = 10; - int verbosity = 1; - int keep_every = 0; - bool collect_output_weight = false; -}; - class IMatrixCollector { public: IMatrixCollector() = default; - void set_parameters(StatParams&& params) { m_params = std::move(params); } + void set_params(gpt_params params) { m_params = std::move(params); } bool collect_imatrix(struct ggml_tensor * t, bool ask, void * user_data); - void save_imatrix() const; - bool load_imatrix(const char * file_name, bool add); - static bool load_imatrix(const char * file_name, std::unordered_map& imatrix); + void save_imatrix(int ncall = -1) const; + bool load_imatrix(const char * file_name); private: std::unordered_map m_stats; - StatParams m_params; + gpt_params m_params; std::mutex m_mutex; int m_last_call = 0; std::vector m_src1_data; std::vector m_ids; // the expert ids from ggml_mul_mat_id - // - void save_imatrix(const char * file_name, const char * dataset) const; - void keep_imatrix(int ncall) const; }; // remove any prefix and suffixes from the name @@ -85,7 +83,7 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void * if (t->op != GGML_OP_MUL_MAT) return false; // why are small batches ignored (<16 tokens)? if (src1->ne[1] < 16 || src1->type != GGML_TYPE_F32) return false; - if (!(wname.substr(0, 4) == "blk." || (m_params.collect_output_weight && wname == "output.weight"))) return false; + if (!(wname.substr(0, 4) == "blk." || (m_params.process_output && wname == "output.weight"))) return false; return true; } @@ -158,16 +156,16 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void * } if (e.ncall > m_last_call) { m_last_call = e.ncall; - if (m_last_call % m_params.n_output_frequency == 0) { + if (m_last_call % m_params.n_out_freq == 0) { save_imatrix(); } - if (m_params.keep_every > 0 && m_last_call%m_params.keep_every == 0) { - keep_imatrix(m_last_call); + if (m_params.n_save_freq > 0 && m_last_call%m_params.n_save_freq == 0) { + save_imatrix(m_last_call); } } } } else { - auto& e = m_stats[wname]; + auto & e = m_stats[wname]; if (e.values.empty()) { e.values.resize(src1->ne[0], 0); e.counts.resize(src1->ne[0], 0); @@ -189,11 +187,11 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void * } if (e.ncall > m_last_call) { m_last_call = e.ncall; - if (m_last_call % m_params.n_output_frequency == 0) { + if (m_last_call % m_params.n_out_freq == 0) { save_imatrix(); } - if (m_params.keep_every > 0 && m_last_call%m_params.keep_every == 0) { - keep_imatrix(m_last_call); + if (m_params.n_save_freq > 0 && m_last_call%m_params.n_save_freq == 0) { + save_imatrix(m_last_call); } } } @@ -201,19 +199,17 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void * return true; } -void IMatrixCollector::save_imatrix() const { - save_imatrix(m_params.ofile.empty() ? "imatrix.dat" : m_params.ofile.c_str(), m_params.dataset.c_str()); -} +void IMatrixCollector::save_imatrix(int ncall) const { + auto fname = m_params.out_file; + if (fname.empty()) { + fname = "imatrix.dat"; + } -void IMatrixCollector::keep_imatrix(int ncall) const { - auto file_name = m_params.ofile; - if (file_name.empty()) file_name = "imatrix.dat"; - file_name += ".at_"; - file_name += std::to_string(ncall); - save_imatrix(file_name.c_str(), m_params.dataset.c_str()); -} + if (ncall > 0) { + fname += ".at_"; + fname += std::to_string(ncall); + } -void IMatrixCollector::save_imatrix(const char * fname, const char * dataset) const { std::ofstream out(fname, std::ios::binary); int n_entries = m_stats.size(); out.write((const char *) &n_entries, sizeof(n_entries)); @@ -236,26 +232,28 @@ void IMatrixCollector::save_imatrix(const char * fname, const char * dataset) co // Write the number of call the matrix was computed with out.write((const char *) &m_last_call, sizeof(m_last_call)); - // Write the dataset name at the end of the file to later on specify it in quantize - int n_dataset = strlen(dataset); - out.write((const char *) &n_dataset, sizeof(n_dataset)); - out.write(dataset, n_dataset); + // Write the input filename at the end of the file to later on specify it in quantize + { + int len = m_params.prompt_file.size(); + out.write((const char *) &len, sizeof(len)); + out.write(m_params.prompt_file.c_str(), len); + } if (m_params.verbosity > 0) { - fprintf(stderr, "\n%s: stored collected data after %d chunks in %s\n", __func__, m_last_call, fname); + fprintf(stderr, "\n%s: stored collected data after %d chunks in %s\n", __func__, m_last_call, fname.c_str()); } } -bool IMatrixCollector::load_imatrix(const char * imatrix_file, std::unordered_map& imatrix_data) { - std::ifstream in(imatrix_file, std::ios::binary); +bool IMatrixCollector::load_imatrix(const char * fname) { + std::ifstream in(fname, std::ios::binary); if (!in) { - printf("%s: failed to open %s\n",__func__,imatrix_file); + printf("%s: failed to open %s\n",__func__, fname); return false; } int n_entries; in.read((char*)&n_entries, sizeof(n_entries)); if (in.fail() || n_entries < 1) { - printf("%s: no data in file %s\n", __func__, imatrix_file); + printf("%s: no data in file %s\n", __func__, fname); return false; } for (int i = 0; i < n_entries; ++i) { @@ -263,23 +261,22 @@ bool IMatrixCollector::load_imatrix(const char * imatrix_file, std::unordered_ma std::vector name_as_vec(len+1); in.read((char *)name_as_vec.data(), len); if (in.fail()) { - printf("%s: failed reading name for entry %d from %s\n",__func__,i+1,imatrix_file); + printf("%s: failed reading name for entry %d from %s\n",__func__,i+1, fname); return false; } name_as_vec[len] = 0; std::string name{name_as_vec.data()}; - auto& e = imatrix_data[std::move(name)]; + auto & e = m_stats[std::move(name)]; int ncall; in.read((char*)&ncall, sizeof(ncall)); int nval; in.read((char *)&nval, sizeof(nval)); if (in.fail() || nval < 1) { printf("%s: failed reading number of values for entry %d\n",__func__,i); - imatrix_data = {}; + m_stats = {}; return false; } - // When re-called from load_imatrix() with add set, this will already be created. if (e.values.empty()) { e.values.resize(nval, 0); e.counts.resize(nval, 0); @@ -289,7 +286,7 @@ bool IMatrixCollector::load_imatrix(const char * imatrix_file, std::unordered_ma in.read((char*)tmp.data(), nval*sizeof(float)); if (in.fail()) { printf("%s: failed reading data for entry %d\n",__func__,i); - imatrix_data = {}; + m_stats = {}; return false; } @@ -304,13 +301,6 @@ bool IMatrixCollector::load_imatrix(const char * imatrix_file, std::unordered_ma return true; } -bool IMatrixCollector::load_imatrix(const char * file_name, bool add) { - if (!add) { - m_stats.clear(); - } - return load_imatrix(file_name, m_stats); -} - static IMatrixCollector g_collector; static bool ik_collect_imatrix(struct ggml_tensor * t, bool ask, void * user_data) { @@ -324,7 +314,7 @@ struct results_log_softmax { float prob; }; -static std::vector softmax(const std::vector& logits) { +static std::vector softmax(const std::vector & logits) { std::vector probs(logits.size()); float max_logit = logits[0]; for (float v : logits) { @@ -358,8 +348,7 @@ static results_log_softmax log_softmax(int n_vocab, const float * logits, int to static void process_logits( int n_vocab, const float * logits, const int * tokens, int n_token, std::vector & workers, - double & nll, double & nll2, float * logit_history, float * prob_history -) { + double & nll, double & nll2, float * logit_history, float * prob_history) { std::mutex mutex; int counter = 0; auto compute = [&mutex, &counter, &nll, &nll2, logit_history, prob_history, n_vocab, logits, tokens, n_token] () { @@ -391,8 +380,7 @@ static void process_logits( } } -static bool compute_imatrix(llama_context * ctx, const gpt_params & params, bool compute_ppl, int from_chunk) { - +static bool compute_imatrix(llama_context * ctx, const gpt_params & params) { const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx)); GGML_ASSERT(llama_add_eos_token(llama_get_model(ctx)) != 1); const int n_ctx = llama_n_ctx(ctx); @@ -405,13 +393,13 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params, bool auto tim2 = std::chrono::high_resolution_clock::now(); fprintf(stderr, "%s: tokenization took %g ms\n",__func__,1e-3*std::chrono::duration_cast(tim2-tim1).count()); - if (from_chunk > 0) { - if (size_t((from_chunk + 2)*n_ctx) >= tokens.size()) { - fprintf(stderr, "%s: there will be not enough tokens left after removing %d chunks\n", __func__, from_chunk); + if (params.i_chunk > 0) { + if (size_t((params.i_chunk + 2)*n_ctx) >= tokens.size()) { + fprintf(stderr, "%s: there will be not enough tokens left after removing %d chunks\n", __func__, params.i_chunk); return false; } - fprintf(stderr, "%s: removing initial %d chunks (%d tokens)\n", __func__, from_chunk, from_chunk*n_ctx); - tokens.erase(tokens.begin(), tokens.begin() + from_chunk*n_ctx); + fprintf(stderr, "%s: removing initial %d chunks (%d tokens)\n", __func__, params.i_chunk, params.i_chunk*n_ctx); + tokens.erase(tokens.begin(), tokens.begin() + params.i_chunk*n_ctx); } if (int(tokens.size()) < 2*n_ctx) { @@ -424,7 +412,7 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params, bool std::vector logit_history; std::vector prob_history; - if (compute_ppl) { + if (params.compute_ppl) { logit_history.resize(tokens.size()); prob_history.resize(tokens.size()); } @@ -446,7 +434,7 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params, bool const int num_batches = (n_ctx + n_batch - 1) / n_batch; std::vector logits; - if (compute_ppl && num_batches > 1) { + if (params.compute_ppl && num_batches > 1) { logits.reserve((size_t)n_ctx * n_vocab); } @@ -482,7 +470,7 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params, bool // restore the original token in case it was set to BOS tokens[batch_start] = token_org; - if (compute_ppl && num_batches > 1) { + if (params.compute_ppl && num_batches > 1) { const auto * batch_logits = llama_get_logits(ctx); logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab); } @@ -501,7 +489,7 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params, bool fprintf(stderr, "%.2f minutes\n", total_seconds / 60.0); } - if (compute_ppl) { + if (params.compute_ppl) { const int first = n_ctx/2; const auto all_logits = num_batches > 1 ? logits.data() : llama_get_logits(ctx); process_logits(n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first, @@ -516,7 +504,7 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params, bool } printf("\n"); - if (compute_ppl) { + if (params.compute_ppl) { nll2 /= count; nll /= count; const double ppl = exp(nll); @@ -533,109 +521,32 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params, bool } int main(int argc, char ** argv) { - StatParams sparams; - std::string prev_result_file; - std::string combine_files; - bool compute_ppl = true; - int from_chunk = 0; - std::vector args; - args.push_back(argv[0]); - int iarg = 1; - for (; iarg < argc-1; ++iarg) { - std::string arg{argv[iarg]}; - if (arg == "-o" || arg == "--output-file") { - sparams.ofile = argv[++iarg]; - } - else if (arg == "-ofreq" || arg == "--output-frequency") { - sparams.n_output_frequency = std::stoi(argv[++iarg]); - } - else if (arg == "-ow" || arg == "--output-weight") { - sparams.collect_output_weight = std::stoi(argv[++iarg]); - } - else if (arg == "--verbosity") { - sparams.verbosity = std::stoi(argv[++iarg]); - } else if (arg == "--no-ppl") { - compute_ppl = false; - } else if (arg == "--keep-imatrix") { - sparams.keep_every = std::stoi(argv[++iarg]); - } else if (arg == "--continue-from") { - prev_result_file = argv[++iarg]; - } else if (arg == "--combine") { - combine_files = argv[++iarg]; - } - else if (arg == "--from-chunk") { - from_chunk = std::stoi(argv[++iarg]); - } else { - args.push_back(argv[iarg]); - } - } - if (iarg < argc) { - std::string arg{argv[iarg]}; - if (arg == "--no-ppl") { - compute_ppl = false; - } else { - args.push_back(argv[iarg]); - } - } - gpt_params params; - params.n_batch = 512; + + params.n_ctx = 512; + params.logits_all = true; + params.verbosity = 1; if (!gpt_params_parse(argc, argv, params)) { - gpt_params_print_usage(argc, argv, params); + print_usage(argc, argv, params); return 1; } - params.logits_all = true; params.n_batch = std::min(params.n_batch, params.n_ctx); - print_build_info(); + g_collector.set_params(params); - if (params.seed == LLAMA_DEFAULT_SEED) { - params.seed = time(NULL); - } - - fprintf(stderr, "%s: seed = %u\n", __func__, params.seed); - - std::mt19937 rng(params.seed); - - sparams.dataset = params.prompt_file; - g_collector.set_parameters(std::move(sparams)); - - if (!combine_files.empty()) { - std::vector files; - size_t pos = 0; - while (true) { - auto new_pos = combine_files.find(',', pos); - if (new_pos != std::string::npos) { - files.emplace_back(combine_files.substr(pos, new_pos - pos)); - pos = new_pos + 1; - } else { - files.emplace_back(combine_files.substr(pos)); - break; - } - } - if (files.size() < 2) { - fprintf(stderr, "You must provide at least two comma separated files to use --combine\n"); + for (const auto & in_file : params.in_files) { + printf("%s : loading imatrix from '%s'\n", __func__, in_file.c_str()); + if (!g_collector.load_imatrix(in_file.c_str())) { + fprintf(stderr, "%s : failed to load %s\n", __func__, in_file.c_str()); return 1; } - printf("Combining the following %d files\n", int(files.size())); - for (auto& file : files) { - printf(" %s\n", file.c_str()); - if (!g_collector.load_imatrix(file.c_str(), true)) { - fprintf(stderr, "Failed to load %s\n", file.c_str()); - return 1; - } - } + } + + if (params.in_files.size() > 1) { + printf("%s : saving combined imatrix to '%s'\n", __func__, params.out_file.c_str()); g_collector.save_imatrix(); - return 0; - } - - if (!prev_result_file.empty()) { - if (!g_collector.load_imatrix(prev_result_file.c_str(), false)) { - fprintf(stderr, "=============== Failed to load %s\n", prev_result_file.c_str()); - return 1; - } } llama_backend_init(); @@ -650,6 +561,7 @@ int main(int argc, char ** argv) { // init llama_model * model; llama_context * ctx; + std::tie(model, ctx) = llama_init_from_gpt_params(params); if (model == nullptr || ctx == nullptr) { fprintf(stderr, "%s : failed to init\n", __func__); @@ -668,8 +580,7 @@ int main(int argc, char ** argv) { fprintf(stderr, "%s\n", gpt_params_get_system_info(params).c_str()); } - bool OK = compute_imatrix(ctx, params, compute_ppl, from_chunk); - if (!OK) { + if (!compute_imatrix(ctx, params)) { return 1; } diff --git a/examples/server/server.cpp b/examples/server/server.cpp index d581cad95..74da81dad 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -2360,7 +2360,7 @@ int main(int argc, char ** argv) { // TODO: not great to use extern vars server_log_json = params.log_json; - server_verbose = params.verbose; + server_verbose = params.verbosity > 0; // struct that contains llama context and inference server_context ctx_server; From ee459f40f65810a810151b24eba5b8bd174ceffe Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 6 Jun 2024 19:19:59 +0300 Subject: [PATCH 08/21] server : fix --threads-http arg (#7801) --- common/common.cpp | 9 +++++++++ common/common.h | 2 +- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/common/common.cpp b/common/common.cpp index 601bd2164..cdcb352b5 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1414,6 +1414,14 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.timeout_write = std::stoi(argv[i]); return true; } + if (arg == "--threads-http") { + if (++i >= argc) { + invalid_param = true; + return true; + } + params.n_threads_http = std::stoi(argv[i]); + return true; + } if (arg == "-spf" || arg == "--system-prompt-file") { if (++i >= argc) { invalid_param = true; @@ -1893,6 +1901,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "server", " --ssl-key-file FNAME", "path to file a PEM-encoded SSL private key" }); options.push_back({ "server", " --ssl-cert-file FNAME", "path to file a PEM-encoded SSL certificate" }); options.push_back({ "server", " --timeout N", "server read/write timeout in seconds (default: %d)", params.timeout_read }); + options.push_back({ "server", " --threads-http N", "number of threads used to process HTTP requests (default: %d)", params.n_threads_http }); options.push_back({ "server", " --system-prompt-file FNAME", "set a file to load a system prompt (initial prompt of all slots), this is useful for chat applications" }); options.push_back({ "server", " --log-format {text,json}", diff --git a/common/common.h b/common/common.h index de6238e27..35f5311e1 100644 --- a/common/common.h +++ b/common/common.h @@ -184,7 +184,7 @@ struct gpt_params { int32_t port = 8080; // server listens on this network port int32_t timeout_read = 600; // http read timeout in seconds int32_t timeout_write = timeout_read; // http write timeout in seconds - int32_t n_threads_http = -1; // number of threads to use for http server (-1 = use n_threads) + int32_t n_threads_http = -1; // number of threads to process HTTP requests std::string hostname = "127.0.0.1"; std::string public_path = ""; From c9ee7118d5644dd3df70ea6878b36a9761616aab Mon Sep 17 00:00:00 2001 From: slaren Date: Fri, 7 Jun 2024 08:01:29 +0200 Subject: [PATCH 09/21] check for nans in imatrix and quantize (#7807) * imatrix : detect nan/inf values * quantize : check imatrix for nan/inf values --- examples/imatrix/imatrix.cpp | 8 ++++++++ llama.cpp | 8 ++++++++ 2 files changed, 16 insertions(+) diff --git a/examples/imatrix/imatrix.cpp b/examples/imatrix/imatrix.cpp index 38420041c..e18f49563 100644 --- a/examples/imatrix/imatrix.cpp +++ b/examples/imatrix/imatrix.cpp @@ -151,6 +151,10 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void * for (int j = 0; j < (int)src1->ne[0]; ++j) { e.values[e_start + j] += x[j]*x[j]; e.counts[e_start + j]++; + if (!std::isfinite(e.values[e_start + j])) { + fprintf(stderr, "%f detected in %s\n", e.values[e_start + j], wname.c_str()); + exit(1); + } } } } @@ -183,6 +187,10 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void * for (int j = 0; j < (int)src1->ne[0]; ++j) { e.values[j] += x[j]*x[j]; e.counts[j]++; + if (!std::isfinite(e.values[j])) { + fprintf(stderr, "%f detected in %s\n", e.values[j], wname.c_str()); + exit(1); + } } } if (e.ncall > m_last_call) { diff --git a/llama.cpp b/llama.cpp index 32264a008..8b675ea99 100644 --- a/llama.cpp +++ b/llama.cpp @@ -15237,6 +15237,14 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s if (imatrix_data) { LLAMA_LOG_INFO("================================ Have weights data with %d entries\n",int(imatrix_data->size())); qs.has_imatrix = true; + // check imatrix for nans or infs + for (const auto & kv : *imatrix_data) { + for (float f : kv.second) { + if (!std::isfinite(f)) { + throw std::runtime_error(format("imatrix contains non-finite value %f\n", f)); + } + } + } } } From d5c938cd7716b9a2ace49a43a469dfbffcff4d28 Mon Sep 17 00:00:00 2001 From: pengxin99 Date: Fri, 7 Jun 2024 14:28:26 +0800 Subject: [PATCH 10/21] [SYCL] fix softmax r2r result wrong issue (#7811) --- ggml-sycl.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/ggml-sycl.cpp b/ggml-sycl.cpp index 3ff76474d..0a645b2e1 100644 --- a/ggml-sycl.cpp +++ b/ggml-sycl.cpp @@ -9108,6 +9108,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const // find the sum of exps in the block tmp = warp_reduce_sum(tmp, item_ct1); if (block_size > WARP_SIZE) { + item_ct1.barrier(sycl::access::fence_space::local_space); if (warp_id == 0) { buf[lane_id] = 0.f; } From a5cabd76491f07494c5b8267f921c73f5e2bbfb4 Mon Sep 17 00:00:00 2001 From: woodx <124784234+woodx9@users.noreply.github.com> Date: Fri, 7 Jun 2024 15:09:45 +0800 Subject: [PATCH 11/21] server : do not get prompt in infill mode (#7286) * avoid to get prompt in infill mode and embedding mode * remove embedding mode * refactor format --------- Co-authored-by: wudexiang --- examples/server/server.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 74da81dad..528220607 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -888,7 +888,7 @@ struct server_context { slot.params.input_suffix = json_value(data, "input_suffix", default_params.input_suffix); // get prompt - { + if (!task.infill) { const auto & prompt = data.find("prompt"); if (prompt == data.end()) { send_error(task, "Either \"prompt\" or \"messages\" must be provided", ERROR_TYPE_INVALID_REQUEST); From 7027b27d765db95d4ac6b569d976e387a8715881 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Fri, 7 Jun 2024 11:15:49 +0200 Subject: [PATCH 12/21] server: update cache_prompt documentation [no ci] (#7745) --- examples/server/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/server/README.md b/examples/server/README.md index 0c3db8c84..ccbdcdbdb 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -279,7 +279,7 @@ node index.js `id_slot`: Assign the completion task to an specific slot. If is -1 the task will be assigned to a Idle slot. Default: `-1` - `cache_prompt`: Re-use previously cached prompt from the last request if possible. This may prevent re-caching the prompt from scratch. Default: `false` + `cache_prompt`: Re-use KV cache from a previous request if possible. This way the common prefix does not have to be re-processed, only the suffix that differs between the requests. Because (depending on the backend) the logits are **not** guaranteed to be bit-for-bit identical for different batch sizes (prompt processing vs. token generation) enabling this option can cause nondeterministic results. Default: `false` `system_prompt`: Change the system prompt (initial prompt of all slots), this is useful for chat applications. [See more](#change-system-prompt-on-runtime) From 27615f5ab21060d96953c9c1e223051ab2188f57 Mon Sep 17 00:00:00 2001 From: intelmatt <61025942+intelmatt@users.noreply.github.com> Date: Fri, 7 Jun 2024 05:15:07 -0700 Subject: [PATCH 13/21] cmake : fix BUILD_SHARED_LIBS=ON build (#7784) common depends on pthreads in Linux --- common/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt index 0ec8d6d8d..171530c91 100644 --- a/common/CMakeLists.txt +++ b/common/CMakeLists.txt @@ -84,4 +84,4 @@ endif () target_include_directories(${TARGET} PUBLIC .) target_compile_features(${TARGET} PUBLIC cxx_std_11) -target_link_libraries(${TARGET} PRIVATE ${LLAMA_COMMON_EXTRA_LIBS} PUBLIC llama) +target_link_libraries(${TARGET} PRIVATE ${LLAMA_COMMON_EXTRA_LIBS} PUBLIC llama Threads::Threads) From c00fad71e507ff386d42bd74846fe06d19dd63a4 Mon Sep 17 00:00:00 2001 From: Christian Zhou-Zheng <59622928+christianazinn@users.noreply.github.com> Date: Fri, 7 Jun 2024 08:56:01 -0400 Subject: [PATCH 14/21] gguf-split : change binary multi-byte units to decimal (#7803) --- examples/gguf-split/gguf-split.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/gguf-split/gguf-split.cpp b/examples/gguf-split/gguf-split.cpp index e04feeae3..881f0451c 100644 --- a/examples/gguf-split/gguf-split.cpp +++ b/examples/gguf-split/gguf-split.cpp @@ -61,10 +61,10 @@ static size_t split_str_to_n_bytes(std::string str) { int n; if (str.back() == 'M') { sscanf(str.c_str(), "%d", &n); - n_bytes = (size_t)n * 1024 * 1024; // megabytes + n_bytes = (size_t)n * 1000 * 1000; // megabytes } else if (str.back() == 'G') { sscanf(str.c_str(), "%d", &n); - n_bytes = (size_t)n * 1024 * 1024 * 1024; // gigabytes + n_bytes = (size_t)n * 1000 * 1000 * 1000; // gigabytes } else { throw std::invalid_argument("error: supported units are M (megabytes) or G (gigabytes), but got: " + std::string(1, str.back())); } @@ -284,7 +284,7 @@ struct split_strategy { struct ggml_tensor * t = ggml_get_tensor(ctx_meta, gguf_get_tensor_name(ctx_out, i)); total_size += ggml_nbytes(t); } - total_size = total_size / 1024 / 1024; // convert to megabytes + total_size = total_size / 1000 / 1000; // convert to megabytes printf("split %05d: n_tensors = %d, total_size = %ldM\n", i_split + 1, gguf_get_n_tensors(ctx_out), total_size); i_split++; } From da799b41891e34aac86ce4e173f9c4c0afd4fab3 Mon Sep 17 00:00:00 2001 From: slaren Date: Fri, 7 Jun 2024 19:47:49 +0200 Subject: [PATCH 15/21] vulkan : reuse parent extra for views (#7806) * vulkan : reuse parent extra for views * Fix validation error when multiple compute contexts are used in a graph --------- Co-authored-by: 0cc4m --- ggml-vulkan.cpp | 128 +++++++++++++++++++++--------------------------- 1 file changed, 56 insertions(+), 72 deletions(-) diff --git a/ggml-vulkan.cpp b/ggml-vulkan.cpp index e0c512c0d..128769177 100644 --- a/ggml-vulkan.cpp +++ b/ggml-vulkan.cpp @@ -345,15 +345,12 @@ struct vk_context { }; struct ggml_tensor_extra_gpu { - bool ready; - size_t ctx_idx; vk_buffer_ref buffer_gpu; uint64_t offset; void reset() { - ready = false; ctx_idx = 0; buffer_gpu.reset(); offset = 0; @@ -2949,7 +2946,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context * su const uint64_t d_sz = sizeof(float) * d_ne; vk_buffer d_D = extra->buffer_gpu.lock(); - const uint64_t d_buf_offset = extra->offset; + const uint64_t d_buf_offset = extra->offset + dst->view_offs; GGML_ASSERT(d_D != nullptr); GGML_ASSERT(d_D->size >= d_buf_offset + d_sz * ne02 * ne03); vk_buffer d_X; @@ -2958,12 +2955,12 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context * su uint64_t y_buf_offset = 0; if (!src0_uma) { d_Qx = extra_src0->buffer_gpu.lock(); - qx_buf_offset = extra_src0->offset; + qx_buf_offset = extra_src0->offset + src0->view_offs; GGML_ASSERT(d_Qx != nullptr); } if (!src1_uma) { d_Qy = extra_src1->buffer_gpu.lock(); - qy_buf_offset = extra_src1->offset; + qy_buf_offset = extra_src1->offset + src1->view_offs; GGML_ASSERT(d_Qy != nullptr); } if (qx_needs_dequant) { @@ -3114,7 +3111,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context const uint64_t d_sz = sizeof(float) * d_ne; vk_buffer d_D = extra->buffer_gpu.lock(); - const uint64_t d_buf_offset = extra->offset; + const uint64_t d_buf_offset = extra->offset + dst->view_offs; GGML_ASSERT(d_D != nullptr); vk_buffer d_X; uint64_t x_buf_offset = 0; @@ -3122,12 +3119,12 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context uint64_t y_buf_offset = 0; if(!src0_uma) { d_Qx = extra_src0->buffer_gpu.lock(); - qx_buf_offset = extra_src0->offset; + qx_buf_offset = extra_src0->offset + src0->view_offs; GGML_ASSERT(d_Qx != nullptr); } if(!src1_uma) { d_Qy = extra_src1->buffer_gpu.lock(); - qy_buf_offset = extra_src1->offset; + qy_buf_offset = extra_src1->offset + src1->view_offs; GGML_ASSERT(d_Qy != nullptr); } if (qx_needs_dequant) { @@ -3246,14 +3243,14 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c const uint64_t d_sz = sizeof(float) * d_ne; vk_buffer d_D = extra->buffer_gpu.lock(); - const uint64_t d_buf_offset = extra->offset; + const uint64_t d_buf_offset = extra->offset + dst->view_offs; GGML_ASSERT(d_D != nullptr); vk_buffer d_Qx = extra_src0->buffer_gpu.lock(); - const uint64_t qx_buf_offset = extra_src0->offset; + const uint64_t qx_buf_offset = extra_src0->offset + src0->view_offs; GGML_ASSERT(d_Qx != nullptr); if (!src1_uma) { d_Qy = extra_src1->buffer_gpu.lock(); - qy_buf_offset = extra_src1->offset; + qy_buf_offset = extra_src1->offset + src1->view_offs; GGML_ASSERT(d_Qx != nullptr); } @@ -3323,14 +3320,14 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con const uint64_t d_sz = sizeof(float) * d_ne; vk_buffer d_D = extra->buffer_gpu.lock(); - const uint64_t d_buf_offset = extra->offset; + const uint64_t d_buf_offset = extra->offset + dst->view_offs; GGML_ASSERT(d_D != nullptr); vk_buffer d_Qx = extra_src0->buffer_gpu.lock(); - const uint64_t qx_buf_offset = extra_src0->offset; + const uint64_t qx_buf_offset = extra_src0->offset + src0->view_offs; GGML_ASSERT(d_Qx != nullptr); if (!src1_uma) { d_Qy = extra_src1->buffer_gpu.lock(); - qy_buf_offset = extra_src1->offset; + qy_buf_offset = extra_src1->offset + src1->view_offs; GGML_ASSERT(d_Qx != nullptr); } @@ -3459,7 +3456,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context * const uint64_t d_sz = sizeof(float) * d_ne; vk_buffer d_D = extra->buffer_gpu.lock(); - const uint64_t d_buf_offset = extra->offset; + const uint64_t d_buf_offset = extra->offset + dst->view_offs; GGML_ASSERT(d_D != nullptr); vk_buffer d_X; uint64_t x_buf_offset = 0; @@ -3467,17 +3464,17 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context * uint64_t y_buf_offset = 0; if (!src0_uma) { d_Qx = extra_src0->buffer_gpu.lock(); - qx_buf_offset = extra_src0->offset; + qx_buf_offset = extra_src0->offset + src0->view_offs; GGML_ASSERT(d_Qx != nullptr); } if (!src1_uma) { d_Qy = extra_src1->buffer_gpu.lock(); - qy_buf_offset = extra_src1->offset; + qy_buf_offset = extra_src1->offset + src1->view_offs; GGML_ASSERT(d_Qy != nullptr); } if (!ids_uma) { d_ids = extra_ids->buffer_gpu.lock(); - ids_buf_offset = extra_ids->offset; + ids_buf_offset = extra_ids->offset + ids->view_offs; GGML_ASSERT(d_ids != nullptr); } if (qx_needs_dequant) { @@ -3636,7 +3633,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte const uint64_t d_sz = sizeof(float) * d_ne; vk_buffer d_D = extra->buffer_gpu.lock(); - const uint64_t d_buf_offset = extra->offset; + const uint64_t d_buf_offset = extra->offset + dst->view_offs; GGML_ASSERT(d_D != nullptr); vk_buffer d_X; uint64_t x_buf_offset = 0; @@ -3644,17 +3641,17 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte uint64_t y_buf_offset = 0; if(!src0_uma) { d_Qx = extra_src0->buffer_gpu.lock(); - qx_buf_offset = extra_src0->offset; + qx_buf_offset = extra_src0->offset + src0->view_offs; GGML_ASSERT(d_Qx != nullptr); } if(!src1_uma) { d_Qy = extra_src1->buffer_gpu.lock(); - qy_buf_offset = extra_src1->offset; + qy_buf_offset = extra_src1->offset + src1->view_offs; GGML_ASSERT(d_Qy != nullptr); } if(!ids_uma) { d_ids = extra_ids->buffer_gpu.lock(); - ids_buf_offset = extra_ids->offset; + ids_buf_offset = extra_ids->offset + ids->view_offs; GGML_ASSERT(d_ids != nullptr); } if (qx_needs_dequant) { @@ -3769,9 +3766,9 @@ static void ggml_vk_op_repeat(ggml_backend_vk_context * ctx, vk_context * subctx ggml_tensor_extra_gpu * extra_src0 = (ggml_tensor_extra_gpu *) src0->extra; const vk_buffer src_buf = extra_src0->buffer_gpu.lock(); - const uint64_t src_offset = extra_src0->offset; + const uint64_t src_offset = extra_src0->offset + src0->view_offs; vk_buffer dst_buf = extra->buffer_gpu.lock(); - const uint64_t dst_offset = extra->offset; + const uint64_t dst_offset = extra->offset + dst->view_offs; std::vector copies; @@ -4062,21 +4059,21 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c } GGML_ASSERT(d_D != nullptr); - uint64_t d_buf_offset = (extra->offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment; + uint64_t d_buf_offset = ((extra->offset + dst->view_offs) / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment; GGML_ASSERT(d_buf_offset == extra->offset || op == GGML_OP_CPY); // NOLINT if(!src0_uma) { d_X = extra_src0->buffer_gpu.lock(); - x_buf_offset = extra_src0->offset; + x_buf_offset = extra_src0->offset + src0->view_offs; GGML_ASSERT(d_X != nullptr); } if (use_src1 && !src1_uma) { d_Y = extra_src1->buffer_gpu.lock(); - y_buf_offset = extra_src1->offset; + y_buf_offset = extra_src1->offset + src1->view_offs; GGML_ASSERT(d_Y != nullptr); } if (use_src2 && !src2_uma) { d_Z = extra_src2->buffer_gpu.lock(); - z_buf_offset = extra_src2->offset; + z_buf_offset = extra_src2->offset + src2->view_offs; GGML_ASSERT(d_Z != nullptr); } @@ -4336,7 +4333,7 @@ static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context * subctx, cons ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) dst->extra; const uint32_t src0_type_size = ggml_type_size(src0->type); const uint32_t dst_type_size = ggml_type_size(dst->type); - const uint32_t d_offset = (extra->offset % ctx->device->properties.limits.minStorageBufferOffsetAlignment) / dst_type_size; + const uint32_t d_offset = ((extra->offset + dst->view_offs) % ctx->device->properties.limits.minStorageBufferOffsetAlignment) / dst_type_size; ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CPY, { (uint32_t)ggml_nelements(src0), @@ -5569,6 +5566,13 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod const ggml_tensor * src2 = node->src[2]; switch (node->op) { + // Return on empty ops to avoid generating a compute_ctx and setting exit_tensor + case GGML_OP_RESHAPE: + case GGML_OP_VIEW: + case GGML_OP_PERMUTE: + case GGML_OP_TRANSPOSE: + case GGML_OP_NONE: + return; case GGML_OP_UNARY: switch (ggml_get_unary_op(node)) { case GGML_UNARY_OP_SILU: @@ -5590,10 +5594,6 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_CPY: case GGML_OP_CONT: case GGML_OP_DUP: - case GGML_OP_RESHAPE: - case GGML_OP_VIEW: - case GGML_OP_PERMUTE: - case GGML_OP_TRANSPOSE: case GGML_OP_NORM: case GGML_OP_RMS_NORM: case GGML_OP_DIAG_MASK_INF: @@ -5601,7 +5601,6 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_ROPE: case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: - case GGML_OP_NONE: case GGML_OP_ARGSORT: case GGML_OP_SUM_ROWS: break; @@ -5654,12 +5653,6 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_DUP: ggml_vk_cpy(ctx, ctx->compute_ctx, src0, node); - break; - case GGML_OP_RESHAPE: - case GGML_OP_VIEW: - case GGML_OP_PERMUTE: - case GGML_OP_TRANSPOSE: - case GGML_OP_NONE: break; case GGML_OP_NORM: ggml_vk_norm(ctx, ctx->compute_ctx, src0, node); @@ -5712,7 +5705,6 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod return; } - extra->ready = true; extra->ctx_idx = ctx->compute_ctx->idx; #ifdef GGML_VULKAN_CHECK_RESULTS @@ -5796,8 +5788,6 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_compute_ ggml_vk_check_results_0(ctx, params, tensor); #endif - GGML_ASSERT(extra->ready); - vk_context& subctx = ctx->gc.contexts[extra->ctx_idx]; // Only run if ctx hasn't been submitted yet @@ -5822,8 +5812,6 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_compute_ subctx.out_memcpys.clear(); } - extra->ready = false; - return true; } @@ -5943,7 +5931,9 @@ struct ggml_backend_vk_buffer_context { ~ggml_backend_vk_buffer_context() { ggml_vk_destroy_buffer(dev_buffer); - delete[] temp_tensor_extras; + if (temp_tensor_extras != nullptr) { + delete[] temp_tensor_extras; + } } ggml_tensor_extra_gpu * ggml_vk_alloc_temp_tensor_extra() { @@ -5990,18 +5980,16 @@ GGML_CALL static void ggml_backend_vk_buffer_init_tensor(ggml_backend_buffer_t b #endif ggml_backend_vk_buffer_context * ctx = (ggml_backend_vk_buffer_context *)buffer->context; - ggml_tensor_extra_gpu * extra = ctx->ggml_vk_alloc_temp_tensor_extra(); - if (tensor->view_src != nullptr && tensor->view_src->extra != nullptr) { + if (tensor->view_src != nullptr) { GGML_ASSERT(tensor->view_src->buffer->buft == buffer->buft); - ggml_tensor_extra_gpu * extra_view = (ggml_tensor_extra_gpu *) tensor->view_src->extra; - extra->buffer_gpu = extra_view->buffer_gpu; - extra->offset = extra_view->offset + tensor->view_offs; + GGML_ASSERT(tensor->view_src->extra != nullptr); + tensor->extra = tensor->view_src->extra; } else { + ggml_tensor_extra_gpu * extra = ctx->ggml_vk_alloc_temp_tensor_extra(); extra->buffer_gpu = ctx->dev_buffer; extra->offset = (uint8_t *) tensor->data - (uint8_t *) vk_ptr_base; + tensor->extra = extra; } - - tensor->extra = extra; } GGML_CALL static void ggml_backend_vk_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { @@ -6014,7 +6002,7 @@ GGML_CALL static void ggml_backend_vk_buffer_set_tensor(ggml_backend_buffer_t bu vk_buffer buf = extra->buffer_gpu.lock(); - ggml_vk_buffer_write(ctx->ctx, buf, extra->offset + offset, data, size); + ggml_vk_buffer_write(ctx->ctx, buf, extra->offset + tensor->view_offs + offset, data, size); } GGML_CALL static void ggml_backend_vk_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { @@ -6027,7 +6015,7 @@ GGML_CALL static void ggml_backend_vk_buffer_get_tensor(ggml_backend_buffer_t bu vk_buffer buf = extra->buffer_gpu.lock(); - ggml_vk_buffer_read(ctx->ctx, buf, extra->offset + offset, data, size); + ggml_vk_buffer_read(ctx->ctx, buf, extra->offset + tensor->view_offs + offset, data, size); } GGML_CALL static bool ggml_backend_vk_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) { @@ -6038,7 +6026,7 @@ GGML_CALL static bool ggml_backend_vk_buffer_cpy_tensor(ggml_backend_buffer_t bu vk_buffer src_buf = src_extra->buffer_gpu.lock(); vk_buffer dst_buf = dst_extra->buffer_gpu.lock(); - ggml_vk_buffer_copy(dst_buf, dst_extra->offset, src_buf, src_extra->offset, ggml_nbytes(src)); + ggml_vk_buffer_copy(dst_buf, dst_extra->offset + dst->view_offs, src_buf, src_extra->offset + src->view_offs, ggml_nbytes(src)); return true; } @@ -6264,7 +6252,7 @@ GGML_CALL static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, g vk_buffer buf = extra->buffer_gpu.lock(); - ggml_vk_buffer_write_async(ctx, ctx->transfer_ctx, buf, extra->offset + offset, data, size); + ggml_vk_buffer_write_async(ctx, ctx->transfer_ctx, buf, extra->offset + tensor->view_offs + offset, data, size); } GGML_CALL static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { @@ -6284,7 +6272,7 @@ GGML_CALL static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, c vk_buffer buf = extra->buffer_gpu.lock(); - ggml_vk_buffer_read_async(ctx, ctx->transfer_ctx, buf, extra->offset + offset, data, size); + ggml_vk_buffer_read_async(ctx, ctx->transfer_ctx, buf, extra->offset + tensor->view_offs + offset, data, size); } GGML_CALL static bool ggml_backend_vk_cpy_tensor_async(ggml_backend_t backend, const ggml_tensor * src, ggml_tensor * dst) { @@ -6305,7 +6293,7 @@ GGML_CALL static bool ggml_backend_vk_cpy_tensor_async(ggml_backend_t backend, c vk_buffer src_buf = src_extra->buffer_gpu.lock(); vk_buffer dst_buf = dst_extra->buffer_gpu.lock(); - ggml_vk_buffer_copy_async(ctx->transfer_ctx, dst_buf, dst_extra->offset, src_buf, src_extra->offset, ggml_nbytes(src)); + ggml_vk_buffer_copy_async(ctx->transfer_ctx, dst_buf, dst_extra->offset + dst->view_offs, src_buf, src_extra->offset + src->view_offs, ggml_nbytes(src)); return true; } @@ -6478,11 +6466,7 @@ GGML_CALL static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const // return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16; // } break; case GGML_OP_ROPE: - { - const int mode = ((const int32_t *) op->op_params)[2]; - - return true; - } break; + return true; case GGML_OP_NONE: case GGML_OP_RESHAPE: case GGML_OP_VIEW: @@ -6725,7 +6709,7 @@ static void ggml_vk_print_tensor(ggml_backend_vk_context * ctx, const ggml_tenso ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra; vk_buffer buffer_gpu = extra->buffer_gpu.lock(); - ggml_vk_buffer_read(ctx, buffer_gpu, extra->offset, tensor_data, tensor_size); + ggml_vk_buffer_read(ctx, buffer_gpu, extra->offset + tensor->view_offs, tensor_data, tensor_size); } std::cerr << "TENSOR CHECK " << name << " (" << tensor->name << "): " << ggml_op_name(tensor->op) << std::endl; @@ -6809,7 +6793,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_ } else if (ggml_backend_buffer_is_vk(src0->buffer)) { ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) src0->extra; vk_buffer buffer_gpu = extra->buffer_gpu.lock(); - uint64_t offset = extra->offset; + uint64_t offset = extra->offset + src0->view_offs; if (!ggml_is_contiguous(src0) && ggml_vk_dim01_contiguous(src0)) { for (int i3 = 0; i3 < src0->ne[3]; i3++) { for (int i2 = 0; i2 < src0->ne[2]; i2++) { @@ -6851,7 +6835,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_ } else if (ggml_backend_buffer_is_vk(src1->buffer)) { ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) src1->extra; vk_buffer buffer_gpu = extra->buffer_gpu.lock(); - uint64_t offset = extra->offset; + uint64_t offset = extra->offset + src1->view_offs; if (!ggml_is_contiguous(src1) && ggml_vk_dim01_contiguous(src1)) { for (int i3 = 0; i3 < src1->ne[3]; i3++) { for (int i2 = 0; i2 < src1->ne[2]; i2++) { @@ -6909,7 +6893,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_ } else if (ggml_backend_buffer_is_vk(src2->buffer)) { ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) src2->extra; vk_buffer buffer_gpu = extra->buffer_gpu.lock(); - uint64_t offset = extra->offset; + uint64_t offset = extra->offset + src2->view_offs; if (!ggml_is_contiguous(src2) && ggml_vk_dim01_contiguous(src2)) { for (int i3 = 0; i3 < src2->ne[3]; i3++) { for (int i2 = 0; i2 < src2->ne[2]; i2++) { @@ -7092,11 +7076,11 @@ static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_compute_ ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra; vk_buffer buffer_gpu = extra->buffer_gpu.lock(); - if (extra->offset + tensor_size >= buffer_gpu->size) { - tensor_size = buffer_gpu->size - (extra->offset); + if (extra->offset + tensor->view_offs + tensor_size >= buffer_gpu->size) { + tensor_size = buffer_gpu->size - (extra->offset + tensor->view_offs); } - ggml_vk_buffer_read(ctx, buffer_gpu, extra->offset, tensor_data, tensor_size); + ggml_vk_buffer_read(ctx, buffer_gpu, extra->offset + tensor->view_offs, tensor_data, tensor_size); } float first_error_result = -1.0f; From 7a16ce7db2a74a223f0f3b9cee66d4539c5bce8f Mon Sep 17 00:00:00 2001 From: sasha0552 Date: Sat, 8 Jun 2024 07:50:31 +0000 Subject: [PATCH 16/21] server : smart slot selection using Longest Common Prefix (#7728) * server : Smart selection of available slot using Longest Common Substring * add usage * remove trailing whitespaces * Use Longest Common Prefix (LCP) instead of LCS * Rename argument --- common/common.cpp | 10 +++ common/common.h | 2 + examples/server/server.cpp | 138 ++++++++++++++++++++++++++++++++----- examples/server/utils.hpp | 7 ++ 4 files changed, 140 insertions(+), 17 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index cdcb352b5..d2a8bb69e 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1491,6 +1491,14 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.chat_template = argv[i]; return true; } + if (arg == "--slot-prompt-similarity" || arg == "-sps") { + if (++i >= argc) { + invalid_param = true; + return true; + } + params.slot_prompt_similarity = std::stof(argv[i]); + return true; + } if (arg == "-pps") { params.is_pp_shared = true; return true; @@ -1913,6 +1921,8 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param "set custom jinja chat template (default: template taken from model's metadata)\n" "only commonly used templates are accepted:\n" "https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template" }); + options.push_back({ "server", "-sps, --slot-prompt-similarity SIMILARITY", + "how much the prompt of a request must match the prompt of a slot in order to use that slot (default: %.2f, 0.0 = disabled)\n", params.slot_prompt_similarity }); #ifndef LOG_DISABLE_LOGS options.push_back({ "logging" }); diff --git a/common/common.h b/common/common.h index 35f5311e1..038f9084f 100644 --- a/common/common.h +++ b/common/common.h @@ -203,6 +203,8 @@ struct gpt_params { std::string slot_save_path; + float slot_prompt_similarity = 0.5f; + // batched-bench params bool is_pp_shared = false; diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 528220607..6ffaa8d9f 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -647,6 +647,9 @@ struct server_context { server_metrics metrics; + // Necessary similarity of prompt for slot selection + float slot_prompt_similarity = 0.0f; + ~server_context() { if (ctx) { llama_free(ctx); @@ -795,24 +798,88 @@ struct server_context { return prompt_tokens; } - server_slot * get_slot(int id) { - int64_t t_last = ggml_time_us(); - - server_slot * last_used = nullptr; - + server_slot * get_slot_by_id(int id) { for (server_slot & slot : slots) { - if (slot.id == id && slot.available()) { + if (slot.id == id) { return &slot; } - - // among all available slots, find the one that has been least recently used - if (slot.available() && slot.t_last_used < t_last) { - last_used = &slot; - t_last = slot.t_last_used; - } } - return last_used; + return nullptr; + } + + server_slot * get_available_slot(const std::string & prompt) { + server_slot * ret = nullptr; + + // find the slot that has at least n% prompt similarity + if (ret == nullptr && slot_prompt_similarity != 0.0f && !prompt.empty()) { + int max_lcp_len = 0; + float similarity = 0; + + for (server_slot & slot : slots) { + // skip the slot if it is not available + if (!slot.available()) { + continue; + } + + // skip the slot if it does not contains prompt + if (!slot.prompt.is_string()) { + continue; + } + + // current slot's prompt + std::string slot_prompt = slot.prompt.get(); + + // length of the current slot's prompt + int slot_prompt_len = slot_prompt.size(); + + // length of the Longest Common Prefix between the current slot's prompt and the input prompt + int lcp_len = common_part(slot_prompt, prompt); + + // fraction of the common substring length compared to the current slot's prompt length + similarity = static_cast(lcp_len) / slot_prompt_len; + + // select the current slot if the criteria match + if (lcp_len > max_lcp_len && similarity > slot_prompt_similarity) { + max_lcp_len = lcp_len; + ret = &slot; + } + } + + if (ret != nullptr) { + LOG_VERBOSE("selected slot by lcp similarity", { + {"id_slot", ret->id}, + {"max_lcp_len", max_lcp_len}, + {"similarity", similarity}, + }); + } + } + + // find the slot that has been least recently used + if (ret == nullptr) { + int64_t t_last = ggml_time_us(); + for (server_slot & slot : slots) { + // skip the slot if it is not available + if (!slot.available()) { + continue; + } + + // select the current slot if the criteria match + if (slot.t_last_used < t_last) { + t_last = slot.t_last_used; + ret = &slot; + } + } + + if (ret != nullptr) { + LOG_VERBOSE("selected slot by lru", { + {"id_slot", ret->id}, + {"t_last", t_last}, + }); + } + } + + return ret; } bool launch_slot_with_task(server_slot & slot, const server_task & task) { @@ -1515,13 +1582,29 @@ struct server_context { switch (task.type) { case SERVER_TASK_TYPE_COMPLETION: { - server_slot * slot = get_slot(json_value(task.data, "id_slot", -1)); + int id_slot = json_value(task.data, "id_slot", -1); + std::string prompt = json_value(task.data, "prompt", std::string()); + + server_slot * slot; + + if (id_slot != -1) { + slot = get_slot_by_id(id_slot); + } else { + slot = get_available_slot(prompt); + } + if (slot == nullptr) { // if no slot is available, we defer this task for processing later LOG_VERBOSE("no slot is available", {{"id_task", task.id}}); queue_tasks.defer(task); break; } + if (!slot->available()) { + // if requested slot is unavailable, we defer this task for processing later + LOG_VERBOSE("requested slot is unavailable", {{"id_task", task.id}}); + queue_tasks.defer(task); + break; + } if (task.data.contains("system_prompt")) { std::string sys_prompt = json_value(task.data, "system_prompt", std::string()); @@ -1638,11 +1721,17 @@ struct server_context { case SERVER_TASK_TYPE_SLOT_SAVE: { int id_slot = task.data.at("id_slot"); - server_slot * slot = get_slot(id_slot); + server_slot * slot = get_slot_by_id(id_slot); if (slot == nullptr) { send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); break; } + if (!slot->available()) { + // if requested slot is unavailable, we defer this task for processing later + LOG_VERBOSE("requested slot is unavailable", {{"id_task", task.id}}); + queue_tasks.defer(task); + break; + } const size_t token_count = slot->cache_tokens.size(); const int64_t t_start = ggml_time_us(); @@ -1673,11 +1762,17 @@ struct server_context { case SERVER_TASK_TYPE_SLOT_RESTORE: { int id_slot = task.data.at("id_slot"); - server_slot * slot = get_slot(id_slot); + server_slot * slot = get_slot_by_id(id_slot); if (slot == nullptr) { send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); break; } + if (!slot->available()) { + // if requested slot is unavailable, we defer this task for processing later + LOG_VERBOSE("requested slot is unavailable", {{"id_task", task.id}}); + queue_tasks.defer(task); + break; + } const int64_t t_start = ggml_time_us(); @@ -1715,11 +1810,17 @@ struct server_context { case SERVER_TASK_TYPE_SLOT_ERASE: { int id_slot = task.data.at("id_slot"); - server_slot * slot = get_slot(id_slot); + server_slot * slot = get_slot_by_id(id_slot); if (slot == nullptr) { send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); break; } + if (!slot->available()) { + // if requested slot is unavailable, we defer this task for processing later + LOG_VERBOSE("requested slot is unavailable", {{"id_task", task.id}}); + queue_tasks.defer(task); + break; + } // Erase token cache const size_t n_erased = slot->cache_tokens.size(); @@ -2467,6 +2568,9 @@ int main(int argc, char ** argv) { log_data["api_key"] = "api_key: " + std::to_string(params.api_keys.size()) + " keys loaded"; } + // Necessary similarity of prompt for slot selection + ctx_server.slot_prompt_similarity = params.slot_prompt_similarity; + // load the model if (!ctx_server.load_model(params)) { state.store(SERVER_STATE_ERROR); diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index b7bfb41d3..63fde9c9f 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -253,6 +253,13 @@ static size_t common_part(const std::vector & a, const std::vector< return i; } +static size_t common_part(const std::string & a, const std::string & b) { + size_t i; + for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {} + + return i; +} + static bool ends_with(const std::string & str, const std::string & suffix) { return str.size() >= suffix.size() && 0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix); } From d4d915d351d1f1270d56184bdd46672893e8a5d8 Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Sat, 8 Jun 2024 20:21:08 +0100 Subject: [PATCH 17/21] url: save -mu downloads to new cache location (#7826) * url: save -mu download to new cache location * url: fs_get_cache_file_path util * url: tweak sig of fs_get_cache_file --- common/common.cpp | 20 ++++++++++++-------- common/common.h | 1 + 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index d2a8bb69e..1591790e6 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -200,19 +200,13 @@ void gpt_params_handle_model_default(gpt_params & params) { } params.hf_file = params.model; } else if (params.model.empty()) { - std::string cache_directory = fs_get_cache_directory(); - const bool success = fs_create_directory_with_parents(cache_directory); - if (!success) { - throw std::runtime_error("failed to create cache directory: " + cache_directory); - } - params.model = cache_directory + string_split(params.hf_file, '/').back(); + params.model = fs_get_cache_file(string_split(params.hf_file, '/').back()); } } else if (!params.model_url.empty()) { if (params.model.empty()) { auto f = string_split(params.model_url, '#').front(); f = string_split(f, '?').front(); - f = string_split(f, '/').back(); - params.model = "models/" + f; + params.model = fs_get_cache_file(string_split(f, '/').back()); } } else if (params.model.empty()) { params.model = DEFAULT_MODEL_PATH; @@ -2279,6 +2273,16 @@ std::string fs_get_cache_directory() { return ensure_trailing_slash(cache_directory); } +std::string fs_get_cache_file(const std::string & filename) { + GGML_ASSERT(filename.find(DIRECTORY_SEPARATOR) == std::string::npos); + std::string cache_directory = fs_get_cache_directory(); + const bool success = fs_create_directory_with_parents(cache_directory); + if (!success) { + throw std::runtime_error("failed to create cache directory: " + cache_directory); + } + return cache_directory + filename; +} + // // Model utils diff --git a/common/common.h b/common/common.h index 038f9084f..2345d855e 100644 --- a/common/common.h +++ b/common/common.h @@ -277,6 +277,7 @@ bool fs_validate_filename(const std::string & filename); bool fs_create_directory_with_parents(const std::string & path); std::string fs_get_cache_directory(); +std::string fs_get_cache_file(const std::string & filename); // // Model utils From fe1e3917cfa0f9397a765cfd0aef880674d938d5 Mon Sep 17 00:00:00 2001 From: slaren Date: Sun, 9 Jun 2024 01:43:39 +0200 Subject: [PATCH 18/21] Revert "[SYCL] Update rpc-server.cpp to include SYCL backend (#7682)" (#7808) This reverts commit 9422c5e34bbd302493b77a8f6d546154a1f4fe82. --- examples/rpc/rpc-server.cpp | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/examples/rpc/rpc-server.cpp b/examples/rpc/rpc-server.cpp index 62d828250..7c15d2aa4 100644 --- a/examples/rpc/rpc-server.cpp +++ b/examples/rpc/rpc-server.cpp @@ -6,10 +6,6 @@ #include "ggml-metal.h" #endif -#ifdef GGML_USE_SYCL -#include "ggml-sycl.h" -#endif - #include "ggml-rpc.h" #ifdef _WIN32 # include @@ -83,12 +79,6 @@ static ggml_backend_t create_backend() { if (!backend) { fprintf(stderr, "%s: ggml_backend_metal_init() failed\n", __func__); } -#elif GGML_USE_SYCL - fprintf(stderr, "%s: using SYCL backend\n", __func__); - backend = ggml_backend_sycl_init(0); // init device 0 - if (!backend) { - fprintf(stderr, "%s: ggml_backend_sycl_init() failed\n", __func__); - } #endif // if there aren't GPU Backends fallback to CPU backend From ed9f2521185706481501a5e6d5315397b11802ff Mon Sep 17 00:00:00 2001 From: compilade Date: Sat, 8 Jun 2024 22:34:29 -0400 Subject: [PATCH 19/21] gguf-py : decouple adding metadata from writing in GGUFWriter (#7827) Main changes of this PR is to consolidate GGUFWriter.add_key and GGUFWriter.add_val into GGUFWriter.add_key_value. In addition use_temp_file is now opt-in instead of opt-out defaulting to False. Also GGUFWriter now does not require output file name until when actually writing to it. And GGUFWriter doesn't really need to eagerly prepare the data layout of the metadata --- convert-hf-to-gguf.py | 8 +- gguf-py/gguf/gguf_writer.py | 270 +++++++++++++++------------ gguf-py/scripts/gguf-new-metadata.py | 6 +- 3 files changed, 160 insertions(+), 124 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index a86864f04..0327712d7 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -47,7 +47,7 @@ class Model: _model_classes: dict[str, type[Model]] = {} dir_model: Path - ftype: int + ftype: gguf.LlamaFileType is_big_endian: bool endianess: gguf.GGUFEndian use_temp_file: bool @@ -94,7 +94,7 @@ class Model: ftype_lw: str = ftype_up.lower() # allow templating the file name with the output ftype, useful with the "auto" ftype self.fname_out = fname_out.parent / fname_out.name.format(ftype_lw, outtype=ftype_lw, ftype=ftype_lw, OUTTYPE=ftype_up, FTYPE=ftype_up) - self.gguf_writer = gguf.GGUFWriter(self.fname_out, gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file) + self.gguf_writer = gguf.GGUFWriter(path=None, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file) @classmethod def __init_subclass__(cls): @@ -324,13 +324,13 @@ class Model: def write(self): self.write_tensors() - self.gguf_writer.write_header_to_file() + self.gguf_writer.write_header_to_file(self.fname_out) self.gguf_writer.write_kv_data_to_file() self.gguf_writer.write_tensors_to_file(progress=True) self.gguf_writer.close() def write_vocab(self): - self.gguf_writer.write_header_to_file() + self.gguf_writer.write_header_to_file(self.fname_out) self.gguf_writer.write_kv_data_to_file() self.gguf_writer.close() diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index b93747aff..ed56abfb3 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -5,6 +5,7 @@ import os import shutil import struct import tempfile +from dataclasses import dataclass from enum import Enum, auto from io import BufferedWriter from typing import IO, Any, Sequence, Mapping @@ -30,17 +31,36 @@ from .quants import quant_shape_from_byte_shape logger = logging.getLogger(__name__) +@dataclass +class TensorInfo: + shape: Sequence[int] + dtype: GGMLQuantizationType + nbytes: int + tensor: np.ndarray[Any, Any] | None = None + + +@dataclass +class GGUFValue: + value: Any + type: GGUFValueType + + class WriterState(Enum): + NO_FILE = auto() EMPTY = auto() HEADER = auto() KV_DATA = auto() TI_DATA = auto() + WEIGHTS = auto() class GGUFWriter: - fout: BufferedWriter + fout: BufferedWriter | None + path: os.PathLike[str] | str | None temp_file: tempfile.SpooledTemporaryFile[bytes] | None - tensors: list[np.ndarray[Any, Any]] + tensors: dict[str, TensorInfo] + kv_data: dict[str, GGUFValue] + state: WriterState _simple_value_packing = { GGUFValueType.UINT8: "B", GGUFValueType.INT8: "b", @@ -56,141 +76,140 @@ class GGUFWriter: } def __init__( - self, path: os.PathLike[str] | str, arch: str, use_temp_file: bool = True, + self, path: os.PathLike[str] | str | None, arch: str, use_temp_file: bool = False, endianess: GGUFEndian = GGUFEndian.LITTLE, ): - self.fout = open(path, "wb") + self.fout = None + self.path = path self.arch = arch self.endianess = endianess - self.offset_tensor = 0 self.data_alignment = GGUF_DEFAULT_ALIGNMENT - self.kv_data = bytearray() - self.kv_data_count = 0 - self.ti_data = bytearray() - self.ti_data_count = 0 - self.ti_names = set() self.use_temp_file = use_temp_file self.temp_file = None - self.tensors = [] + self.tensors = dict() + self.kv_data = dict() logger.info("gguf: This GGUF file is for {0} Endian only".format( "Big" if self.endianess == GGUFEndian.BIG else "Little", )) - self.state = WriterState.EMPTY + self.state = WriterState.NO_FILE self.add_architecture() - def write_header_to_file(self) -> None: + def open_output_file(self, path: os.PathLike[str] | str | None = None) -> None: + if self.state is WriterState.EMPTY and self.fout is not None and (path is None or path == self.path): + # allow calling this multiple times as long as the path is the same + return + if self.state is not WriterState.NO_FILE: + raise ValueError(f'Expected output file to be not yet opened, got {self.state}') + + if path is not None: + self.path = path + + if self.path is not None: + if self.fout is not None: + self.fout.close() + self.fout = open(self.path, "wb") + self.state = WriterState.EMPTY + + def write_header_to_file(self, path: os.PathLike[str] | str | None = None) -> None: + self.open_output_file(path) + if self.state is not WriterState.EMPTY: raise ValueError(f'Expected output file to be empty, got {self.state}') self._write_packed(" None: if self.state is not WriterState.HEADER: raise ValueError(f'Expected output file to contain the header, got {self.state}') + assert self.fout is not None - self.fout.write(self.kv_data) + kv_data = bytearray() + + for key, val in self.kv_data.items(): + kv_data += self._pack_val(key, GGUFValueType.STRING, add_vtype=False) + kv_data += self._pack_val(val.value, val.type, add_vtype=True) + + self.fout.write(kv_data) self.flush() self.state = WriterState.KV_DATA def write_ti_data_to_file(self) -> None: if self.state is not WriterState.KV_DATA: raise ValueError(f'Expected output file to contain KV data, got {self.state}') + assert self.fout is not None - self.fout.write(self.ti_data) + ti_data = bytearray() + offset_tensor = 0 + + for name, ti in self.tensors.items(): + ti_data += self._pack_val(name, GGUFValueType.STRING, add_vtype=False) + n_dims = len(ti.shape) + ti_data += self._pack("I", n_dims) + for i in range(n_dims): + ti_data += self._pack("Q", ti.shape[n_dims - 1 - i]) + ti_data += self._pack("I", ti.dtype) + ti_data += self._pack("Q", offset_tensor) + offset_tensor += GGUFWriter.ggml_pad(ti.nbytes, self.data_alignment) + + self.fout.write(ti_data) self.flush() self.state = WriterState.TI_DATA - def add_key(self, key: str) -> None: - self.add_val(key, GGUFValueType.STRING, add_vtype=False) + def add_key_value(self, key: str, val: Any, vtype: GGUFValueType) -> None: + if key in self.kv_data: + raise ValueError(f'Duplicated key name {key!r}') + + self.kv_data[key] = GGUFValue(value=val, type=vtype) def add_uint8(self, key: str, val: int) -> None: - self.add_key(key) - self.add_val(val, GGUFValueType.UINT8) + self.add_key_value(key,val, GGUFValueType.UINT8) def add_int8(self, key: str, val: int) -> None: - self.add_key(key) - self.add_val(val, GGUFValueType.INT8) + self.add_key_value(key, val, GGUFValueType.INT8) def add_uint16(self, key: str, val: int) -> None: - self.add_key(key) - self.add_val(val, GGUFValueType.UINT16) + self.add_key_value(key, val, GGUFValueType.UINT16) def add_int16(self, key: str, val: int) -> None: - self.add_key(key) - self.add_val(val, GGUFValueType.INT16) + self.add_key_value(key, val, GGUFValueType.INT16) def add_uint32(self, key: str, val: int) -> None: - self.add_key(key) - self.add_val(val, GGUFValueType.UINT32) + self.add_key_value(key, val, GGUFValueType.UINT32) def add_int32(self, key: str, val: int) -> None: - self.add_key(key) - self.add_val(val, GGUFValueType.INT32) + self.add_key_value(key, val, GGUFValueType.INT32) def add_float32(self, key: str, val: float) -> None: - self.add_key(key) - self.add_val(val, GGUFValueType.FLOAT32) + self.add_key_value(key, val, GGUFValueType.FLOAT32) def add_uint64(self, key: str, val: int) -> None: - self.add_key(key) - self.add_val(val, GGUFValueType.UINT64) + self.add_key_value(key, val, GGUFValueType.UINT64) def add_int64(self, key: str, val: int) -> None: - self.add_key(key) - self.add_val(val, GGUFValueType.INT64) + self.add_key_value(key, val, GGUFValueType.INT64) def add_float64(self, key: str, val: float) -> None: - self.add_key(key) - self.add_val(val, GGUFValueType.FLOAT64) + self.add_key_value(key, val, GGUFValueType.FLOAT64) def add_bool(self, key: str, val: bool) -> None: - self.add_key(key) - self.add_val(val, GGUFValueType.BOOL) + self.add_key_value(key, val, GGUFValueType.BOOL) def add_string(self, key: str, val: str) -> None: if not val: return - self.add_key(key) - self.add_val(val, GGUFValueType.STRING) + self.add_key_value(key, val, GGUFValueType.STRING) def add_array(self, key: str, val: Sequence[Any]) -> None: if not isinstance(val, Sequence): raise ValueError("Value must be a sequence for array type") - self.add_key(key) - self.add_val(val, GGUFValueType.ARRAY) - - def add_val(self, val: Any, vtype: GGUFValueType | None = None, add_vtype: bool = True) -> None: - if vtype is None: - vtype = GGUFValueType.get_type(val) - - if add_vtype: - self.kv_data += self._pack("I", vtype) - self.kv_data_count += 1 - - pack_fmt = self._simple_value_packing.get(vtype) - if pack_fmt is not None: - self.kv_data += self._pack(pack_fmt, val, skip_pack_prefix = vtype == GGUFValueType.BOOL) - elif vtype == GGUFValueType.STRING: - encoded_val = val.encode("utf-8") if isinstance(val, str) else val - self.kv_data += self._pack("Q", len(encoded_val)) - self.kv_data += encoded_val - elif vtype == GGUFValueType.ARRAY and isinstance(val, Sequence) and val: - ltype = GGUFValueType.get_type(val[0]) - if not all(GGUFValueType.get_type(i) is ltype for i in val[1:]): - raise ValueError("All items in a GGUF array should be of the same type") - self.kv_data += self._pack("I", ltype) - self.kv_data += self._pack("Q", len(val)) - for item in val: - self.add_val(item, add_vtype=False) - else: - raise ValueError("Invalid GGUF metadata value type or value") + self.add_key_value(key, val, GGUFValueType.ARRAY) @staticmethod def ggml_pad(x: int, n: int) -> int: @@ -200,16 +219,12 @@ class GGUFWriter: self, name: str, tensor_shape: Sequence[int], tensor_dtype: np.dtype, tensor_nbytes: int, raw_dtype: GGMLQuantizationType | None = None, ) -> None: - if self.state is not WriterState.EMPTY: - raise ValueError(f'Expected output file to be empty, got {self.state}') + if self.state is not WriterState.NO_FILE: + raise ValueError(f'Expected output file to be not yet opened, got {self.state}') - if name in self.ti_names: - raise ValueError(f'Duplicated tensor name {name}') - self.ti_names.add(name) + if name in self.tensors: + raise ValueError(f'Duplicated tensor name {name!r}') - encoded_name = name.encode("utf-8") - self.ti_data += self._pack("Q", len(encoded_name)) - self.ti_data += encoded_name if raw_dtype is None: if tensor_dtype == np.float16: dtype = GGMLQuantizationType.F16 @@ -231,14 +246,8 @@ class GGUFWriter: dtype = raw_dtype if tensor_dtype == np.uint8: tensor_shape = quant_shape_from_byte_shape(tensor_shape, raw_dtype) - n_dims = len(tensor_shape) - self.ti_data += self._pack("I", n_dims) - for i in range(n_dims): - self.ti_data += self._pack("Q", tensor_shape[n_dims - 1 - i]) - self.ti_data += self._pack("I", dtype) - self.ti_data += self._pack("Q", self.offset_tensor) - self.offset_tensor += GGUFWriter.ggml_pad(tensor_nbytes, self.data_alignment) - self.ti_data_count += 1 + + self.tensors[name] = TensorInfo(shape=tensor_shape, dtype=dtype, nbytes=tensor_nbytes) def add_tensor( self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None, @@ -252,10 +261,10 @@ class GGUFWriter: self.temp_file = fp shape: Sequence[int] = raw_shape if raw_shape is not None else tensor.shape - self.add_tensor_info(name, shape, tensor.dtype, tensor.nbytes, raw_dtype = raw_dtype) + self.add_tensor_info(name, shape, tensor.dtype, tensor.nbytes, raw_dtype=raw_dtype) if self.temp_file is None: - self.tensors.append(tensor) + self.tensors[name].tensor = tensor return tensor.tofile(self.temp_file) @@ -267,8 +276,9 @@ class GGUFWriter: fp.write(bytes([0] * pad)) def write_tensor_data(self, tensor: np.ndarray[Any, Any]) -> None: - if self.state is not WriterState.TI_DATA: - raise ValueError(f'Expected output file to contain tensor info, got {self.state}') + if self.state is not WriterState.TI_DATA and self.state is not WriterState.WEIGHTS: + raise ValueError(f'Expected output file to contain tensor info or weights, got {self.state}') + assert self.fout is not None if self.endianess == GGUFEndian.BIG: tensor.byteswap(inplace=True) @@ -276,50 +286,51 @@ class GGUFWriter: tensor.tofile(self.fout) self.write_padding(self.fout, tensor.nbytes) + self.state = WriterState.WEIGHTS + def write_tensors_to_file(self, *, progress: bool = False) -> None: self.write_ti_data_to_file() + assert self.fout is not None + self.write_padding(self.fout, self.fout.tell()) if self.temp_file is None: - self.tensors.reverse() # to pop from the "beginning" in constant time + bar = None if progress: from tqdm import tqdm - total_bytes = sum(t.nbytes for t in self.tensors) + total_bytes = sum(t.nbytes for t in self.tensors.values()) bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True) - while True: - try: - tensor = self.tensors.pop() - except IndexError: - break - tensor.tofile(self.fout) - bar.update(tensor.nbytes) - self.write_padding(self.fout, tensor.nbytes) - return - while True: - try: - tensor = self.tensors.pop() - except IndexError: - break - tensor.tofile(self.fout) - self.write_padding(self.fout, tensor.nbytes) - return + # relying on the fact that Python dicts preserve insertion order (since 3.7) + for ti in self.tensors.values(): + assert ti.tensor is not None # can only iterate once over the tensors + assert ti.tensor.nbytes == ti.nbytes + ti.tensor.tofile(self.fout) + if bar is not None: + bar.update(ti.nbytes) + self.write_padding(self.fout, ti.nbytes) + ti.tensor = None + else: + self.temp_file.seek(0) - self.temp_file.seek(0) + shutil.copyfileobj(self.temp_file, self.fout) + self.flush() + self.temp_file.close() - shutil.copyfileobj(self.temp_file, self.fout) - self.flush() - self.temp_file.close() + self.state = WriterState.WEIGHTS def flush(self) -> None: + assert self.fout is not None self.fout.flush() def close(self) -> None: - self.fout.close() + if self.fout is not None: + self.fout.close() + self.fout = None def add_architecture(self) -> None: self.add_string(Keys.General.ARCHITECTURE, self.arch) @@ -449,7 +460,7 @@ class GGUFWriter: def add_rope_scaling_factor(self, value: float) -> None: self.add_float32(Keys.Rope.SCALING_FACTOR.format(arch=self.arch), value) - def add_rope_scaling_attn_factors(self, value: Sequence[float]) -> None: + def add_rope_scaling_attn_factors(self, value: float) -> None: self.add_float32(Keys.Rope.SCALING_ATTN_FACTOR.format(arch=self.arch), value) def add_rope_scaling_orig_ctx_len(self, value: int) -> None: @@ -571,5 +582,32 @@ class GGUFWriter: pack_prefix = '<' if self.endianess == GGUFEndian.LITTLE else '>' return struct.pack(f'{pack_prefix}{fmt}', value) + def _pack_val(self, val: Any, vtype: GGUFValueType, add_vtype: bool) -> bytes: + kv_data = bytearray() + + if add_vtype: + kv_data += self._pack("I", vtype) + + pack_fmt = self._simple_value_packing.get(vtype) + if pack_fmt is not None: + kv_data += self._pack(pack_fmt, val, skip_pack_prefix = vtype == GGUFValueType.BOOL) + elif vtype == GGUFValueType.STRING: + encoded_val = val.encode("utf-8") if isinstance(val, str) else val + kv_data += self._pack("Q", len(encoded_val)) + kv_data += encoded_val + elif vtype == GGUFValueType.ARRAY and isinstance(val, Sequence) and val: + ltype = GGUFValueType.get_type(val[0]) + if not all(GGUFValueType.get_type(i) is ltype for i in val[1:]): + raise ValueError("All items in a GGUF array should be of the same type") + kv_data += self._pack("I", ltype) + kv_data += self._pack("Q", len(val)) + for item in val: + kv_data += self._pack_val(item, ltype, add_vtype=False) + else: + raise ValueError("Invalid GGUF metadata value type or value") + + return kv_data + def _write_packed(self, fmt: str, value: Any, skip_pack_prefix: bool = False) -> None: + assert self.fout is not None self.fout.write(self._pack(fmt, value, skip_pack_prefix)) diff --git a/gguf-py/scripts/gguf-new-metadata.py b/gguf-py/scripts/gguf-new-metadata.py index 21e91180c..c4b90d581 100755 --- a/gguf-py/scripts/gguf-new-metadata.py +++ b/gguf-py/scripts/gguf-new-metadata.py @@ -101,8 +101,7 @@ def copy_with_new_metadata(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new logger.debug(f'Copying {field.name}') if val.value is not None: - writer.add_key(field.name) - writer.add_val(val.value, val.type) + writer.add_key_value(field.name, val.value, val.type) if gguf.Keys.Tokenizer.CHAT_TEMPLATE in new_metadata: logger.debug('Adding chat template(s)') @@ -111,8 +110,7 @@ def copy_with_new_metadata(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new for key, val in new_metadata.items(): logger.debug(f'Adding {key}: "{val.value}" {val.description}') - writer.add_key(key) - writer.add_val(val.value, val.type) + writer.add_key_value(key, val.value, val.type) total_bytes = 0 From 5795b941827fdec6c1662986de962badff456718 Mon Sep 17 00:00:00 2001 From: compilade Date: Sat, 8 Jun 2024 22:47:25 -0400 Subject: [PATCH 20/21] convert-hf : match model part name prefix and suffix (#7687) In #7075, to fix the conversion of (some) models using model-00001-of-00001.safetensors instead of model.safetensors for a single model part we simply used the same logic as the part count to get the part names. But this doesn't always work correctly, like when unusual additional model files like consolidated.safetensors in https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3 are present. This commit matching both the prefix and the suffix of the model part names should fix this problem without breaking any previously-supported upstream models. But according to report by @teleprint-me there is still some persistent problem, but shall do in the meantime. --- convert-hf-to-gguf.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 0327712d7..b38f48edf 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -73,10 +73,10 @@ class Model: self.endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE self.use_temp_file = use_temp_file self.lazy = not eager - self.part_names = Model.get_model_part_names(self.dir_model, ".safetensors") + self.part_names = Model.get_model_part_names(self.dir_model, "model", ".safetensors") self.is_safetensors = len(self.part_names) > 0 if not self.is_safetensors: - self.part_names = Model.get_model_part_names(self.dir_model, ".bin") + self.part_names = Model.get_model_part_names(self.dir_model, "pytorch_model", ".bin") self.hparams = Model.load_hparams(self.dir_model) self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer"]) self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count) @@ -335,10 +335,10 @@ class Model: self.gguf_writer.close() @staticmethod - def get_model_part_names(dir_model: Path, suffix: str) -> list[str]: + def get_model_part_names(dir_model: Path, prefix: str, suffix: str) -> list[str]: part_names: list[str] = [] for filename in os.listdir(dir_model): - if filename.endswith(suffix): + if filename.startswith(prefix) and filename.endswith(suffix): part_names.append(filename) part_names.sort() From 2decf57bc6e4a6b45176c3727d964a01161beecc Mon Sep 17 00:00:00 2001 From: sasha0552 Date: Sun, 9 Jun 2024 06:39:25 +0000 Subject: [PATCH 21/21] convert-hf : set the model name based on cli arg, if present (#7693) `--model-name` argument was added a while ago but did not do anything. This commit fixes this issue and enables this feature. --- convert-hf-to-gguf.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index b38f48edf..025405a2c 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -52,6 +52,7 @@ class Model: endianess: gguf.GGUFEndian use_temp_file: bool lazy: bool + model_name: str | None part_names: list[str] is_safetensors: bool hparams: dict[str, Any] @@ -64,7 +65,7 @@ class Model: # subclasses should define this! model_arch: gguf.MODEL_ARCH - def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool, use_temp_file: bool, eager: bool): + def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool, use_temp_file: bool, eager: bool, model_name: str | None): if type(self) is Model: raise TypeError(f"{type(self).__name__!r} should not be directly instantiated") self.dir_model = dir_model @@ -73,6 +74,7 @@ class Model: self.endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE self.use_temp_file = use_temp_file self.lazy = not eager + self.model_name = model_name self.part_names = Model.get_model_part_names(self.dir_model, "model", ".safetensors") self.is_safetensors = len(self.part_names) > 0 if not self.is_safetensors: @@ -182,7 +184,7 @@ class Model: return new_name def set_gguf_parameters(self): - self.gguf_writer.add_name(self.dir_model.name) + self.gguf_writer.add_name(self.dir_model.name if self.model_name is None else self.model_name) self.gguf_writer.add_block_count(self.block_count) if (n_ctx := self.find_hparam(["max_position_embeddings", "n_ctx"], optional=True)) is not None: @@ -665,7 +667,7 @@ class GPTNeoXModel(Model): def set_gguf_parameters(self): block_count = self.hparams["num_hidden_layers"] - self.gguf_writer.add_name(self.dir_model.name) + self.gguf_writer.add_name(self.dir_model.name if self.model_name is None else self.model_name) self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"]) self.gguf_writer.add_embedding_length(self.hparams["hidden_size"]) self.gguf_writer.add_block_count(block_count) @@ -798,7 +800,7 @@ class MPTModel(Model): def set_gguf_parameters(self): block_count = self.hparams["n_layers"] - self.gguf_writer.add_name(self.dir_model.name) + self.gguf_writer.add_name(self.dir_model.name if self.model_name is None else self.model_name) self.gguf_writer.add_context_length(self.hparams["max_seq_len"]) self.gguf_writer.add_embedding_length(self.hparams["d_model"]) self.gguf_writer.add_block_count(block_count) @@ -850,7 +852,7 @@ class OrionModel(Model): raise ValueError("gguf: can not find ctx length parameter.") self.gguf_writer.add_file_type(self.ftype) - self.gguf_writer.add_name(self.dir_model.name) + self.gguf_writer.add_name(self.dir_model.name if self.model_name is None else self.model_name) self.gguf_writer.add_source_hf_repo(hf_repo) self.gguf_writer.add_tensor_data_layout("Meta AI original pth") self.gguf_writer.add_context_length(ctx_length) @@ -887,7 +889,7 @@ class BaichuanModel(Model): else: raise ValueError("gguf: can not find ctx length parameter.") - self.gguf_writer.add_name(self.dir_model.name) + self.gguf_writer.add_name(self.dir_model.name if self.model_name is None else self.model_name) self.gguf_writer.add_source_hf_repo(hf_repo) self.gguf_writer.add_tensor_data_layout("Meta AI original pth") self.gguf_writer.add_context_length(ctx_length) @@ -1010,7 +1012,7 @@ class XverseModel(Model): else: raise ValueError("gguf: can not find ctx length parameter.") - self.gguf_writer.add_name(self.dir_model.name) + self.gguf_writer.add_name(self.dir_model.name if self.model_name is None else self.model_name) self.gguf_writer.add_source_hf_repo(hf_repo) self.gguf_writer.add_tensor_data_layout("Meta AI original pth") self.gguf_writer.add_context_length(ctx_length) @@ -1206,7 +1208,7 @@ class StableLMModel(Model): hparams = self.hparams block_count = hparams["num_hidden_layers"] - self.gguf_writer.add_name(self.dir_model.name) + self.gguf_writer.add_name(self.dir_model.name if self.model_name is None else self.model_name) self.gguf_writer.add_context_length(hparams["max_position_embeddings"]) self.gguf_writer.add_embedding_length(hparams["hidden_size"]) self.gguf_writer.add_block_count(block_count) @@ -1681,7 +1683,7 @@ class GPT2Model(Model): model_arch = gguf.MODEL_ARCH.GPT2 def set_gguf_parameters(self): - self.gguf_writer.add_name(self.dir_model.name) + self.gguf_writer.add_name(self.dir_model.name if self.model_name is None else self.model_name) self.gguf_writer.add_block_count(self.hparams["n_layer"]) self.gguf_writer.add_context_length(self.hparams["n_ctx"]) self.gguf_writer.add_embedding_length(self.hparams["n_embd"]) @@ -2248,7 +2250,7 @@ class GemmaModel(Model): hparams = self.hparams block_count = hparams["num_hidden_layers"] - self.gguf_writer.add_name(self.dir_model.name) + self.gguf_writer.add_name(self.dir_model.name if self.model_name is None else self.model_name) self.gguf_writer.add_context_length(hparams["max_position_embeddings"]) self.gguf_writer.add_embedding_length(hparams["hidden_size"]) self.gguf_writer.add_block_count(block_count) @@ -2348,7 +2350,7 @@ class MambaModel(Model): # Fail early for models which don't have a block expansion factor of 2 assert d_inner == 2 * d_model - self.gguf_writer.add_name(self.dir_model.name) + self.gguf_writer.add_name(self.dir_model.name if self.model_name is None else self.model_name) self.gguf_writer.add_context_length(2**20) # arbitrary value; for those who use the default self.gguf_writer.add_embedding_length(d_model) self.gguf_writer.add_feed_forward_length(0) # unused, but seemingly required when loading @@ -2852,7 +2854,7 @@ def main() -> None: logger.error(f"Model {hparams['architectures'][0]} is not supported") sys.exit(1) - model_instance = model_class(dir_model, ftype_map[args.outtype], fname_out, args.bigendian, args.use_temp_file, args.no_lazy) + model_instance = model_class(dir_model, ftype_map[args.outtype], fname_out, args.bigendian, args.use_temp_file, args.no_lazy, args.model_name) logger.info("Set model parameters") model_instance.set_gguf_parameters()