diff --git a/common/sampling.cpp b/common/sampling.cpp index df1b26a90..47e446a8d 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -35,6 +35,8 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_ result->prev.resize(params.n_prev); + result->n_considered = 0; + llama_sampling_set_rng_seed(result, params.seed); return result; @@ -64,6 +66,7 @@ void llama_sampling_reset(llama_sampling_context * ctx) { std::fill(ctx->prev.begin(), ctx->prev.end(), 0); ctx->cur.clear(); + ctx->n_considered = 0; } void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed) { @@ -253,6 +256,8 @@ static llama_token llama_sampling_sample_impl( } } + ctx_sampling->n_considered = cur_p.size; + return id; } diff --git a/common/sampling.h b/common/sampling.h index df4dbbc3b..5cabee561 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -82,6 +82,7 @@ struct llama_sampling_context { // TODO: replace with ring-buffer std::vector<llama_token> prev; std::vector<llama_token_data> cur; + size_t n_considered; std::mt19937 rng; }; diff --git a/convert-hf-to-gguf-update.py b/convert-hf-to-gguf-update.py index 46a225462..ae901e24c 100755 --- a/convert-hf-to-gguf-update.py +++ b/convert-hf-to-gguf-update.py @@ -67,6 +67,7 @@ models = [ {"name": "gpt-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/openai-community/gpt2", }, {"name": "refact", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/smallcloudai/Refact-1_6-base", }, {"name": "command-r", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/CohereForAI/c4ai-command-r-v01", }, + {"name": "olmo", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/allenai/OLMo-1.7-7B-hf", }, ] # make directory "models/tokenizers" if it doesn't exist diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index f7441e6b8..f65d9320e 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -314,6 +314,9 @@ class Model(ABC): if chkhsh == "9c2227e4dd922002fb81bde4fc02b0483ca4f12911410dee2255e4987644e3f8": # ref: https://huggingface.co/CohereForAI/c4ai-command-r-v01 res = "command-r" + if chkhsh == "b6dc8df998e1cfbdc4eac8243701a65afe638679230920b50d6f17d81c098166": + # ref: https://huggingface.co/allenai/OLMo-1.7-7B-hf + res = "olmo" if res is None: logger.warning("\n") @@ -2831,8 +2834,9 @@ class OlmoModel(Model): def set_gguf_parameters(self): super().set_gguf_parameters() self.gguf_writer.add_layer_norm_eps(1e-5) - if "clip_qkv" in self.hparams is not None: - self.gguf_writer.add_clamp_kqv(self.hparams["clip_qkv"]) + clip_qkv = self.hparams.get("clip_qkv") + if clip_qkv is not None: + self.gguf_writer.add_clamp_kqv(clip_qkv) # Same as super class, but permuting q_proj, k_proj # Copied from: LlamaModel diff --git a/docs/HOWTO-add-model.md b/docs/HOWTO-add-model.md index a56b78344..48769cdf6 100644 --- a/docs/HOWTO-add-model.md +++ b/docs/HOWTO-add-model.md @@ -96,9 +96,9 @@ NOTE: The dimensions in `ggml` are typically in the reverse order of the `pytorc This is the funniest part, you have to provide the inference graph implementation of the new model architecture in `llama_build_graph`. -Have a look to existing implementation like `build_llama`, `build_dbrx` or `build_bert`. +Have a look at existing implementation like `build_llama`, `build_dbrx` or `build_bert`. -When implementing a new graph, please note that the underlying `ggml` backends might not support them all, support of missing backend operations can be added in another PR.
+When implementing a new graph, please note that the underlying `ggml` backends might not support them all, support for missing backend operations can be added in another PR. Note: to debug the inference graph: you can use [eval-callback](../examples/eval-callback). diff --git a/examples/imatrix/imatrix.cpp b/examples/imatrix/imatrix.cpp index 84037c96d..da7cfeaee 100644 --- a/examples/imatrix/imatrix.cpp +++ b/examples/imatrix/imatrix.cpp @@ -20,6 +20,7 @@ struct Stats { std::vector<float> values; + std::vector<int> counts; int ncall = 0; }; @@ -122,12 +123,10 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void * auto & e = m_stats[wname]; ++e.ncall; - // NOTE: since we select top-k experts, the number of calls for the expert tensors will be k times larger - // using the following line, we can correct for that if needed by replacing the line above with: - //if (idx == t->src[0]->ne[0] - 1) ++e.ncall; if (e.values.empty()) { e.values.resize(src1->ne[0]*n_as, 0); + e.counts.resize(src1->ne[0]*n_as, 0); } else if (e.values.size() != (size_t)src1->ne[0]*n_as) { fprintf(stderr, "Oops: inconsistent size for %s (%d vs %d)\n", wname.c_str(), (int)e.values.size(), (int)src1->ne[0]*n_as); @@ -154,6 +153,7 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void * for (int j = 0; j < (int)src1->ne[0]; ++j) { e.values[e_start + j] += x[j]*x[j]; + e.counts[e_start + j]++; } } } @@ -171,6 +171,7 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void * auto& e = m_stats[wname]; if (e.values.empty()) { e.values.resize(src1->ne[0], 0); + e.counts.resize(src1->ne[0], 0); } else if (e.values.size() != (size_t)src1->ne[0]) { fprintf(stderr, "Oops: inconsistent size for %s (%d vs %d)\n", wname.c_str(), (int)e.values.size(), (int)src1->ne[0]); @@ -184,6 +185,7 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void * const float * x = data + row * src1->ne[0]; for (int j = 0; j < (int)src1->ne[0]; ++j) { e.values[j] += x[j]*x[j]; + e.counts[j]++; } } if (e.ncall > m_last_call) { @@ -223,7 +225,13 @@ void IMatrixCollector::save_imatrix(const char * fname, const char * dataset) co out.write((const char *) &p.second.ncall, sizeof(p.second.ncall)); int nval = p.second.values.size(); out.write((const char *) &nval, sizeof(nval)); - if (nval > 0) out.write((const char *) p.second.values.data(), nval * sizeof(float)); + if (nval > 0) { + std::vector<float> tmp(nval); + for (int i = 0; i < nval; i++) { + tmp[i] = (p.second.values[i] / static_cast<float>(p.second.counts[i])) * static_cast<float>(p.second.ncall); + } + out.write((const char*)tmp.data(), nval*sizeof(float)); + } } // Write the number of call the matrix was computed with @@ -271,14 +279,28 @@ bool IMatrixCollector::load_imatrix(const char * imatrix_file, std::unordered_ma imatrix_data = {}; return false; } - e.values.resize(nval); - in.read((char*)e.values.data(), nval*sizeof(float)); + + // When re-called from load_imatrix() with add set, this will already be created. + if (e.values.empty()) { + e.values.resize(nval, 0); + e.counts.resize(nval, 0); + } + + std::vector<float> tmp(nval); + in.read((char*)tmp.data(), nval*sizeof(float)); if (in.fail()) { printf("%s: failed reading data for entry %d\n",__func__,i); imatrix_data = {}; return false; } - e.ncall = ncall; + + // Recreate the state as expected by save_imatrix(), and correct for weighted sum.
+ for (int i = 0; i < nval; i++) { + e.values[i] += tmp[i]; + e.counts[i] += ncall; + } + e.ncall += ncall; + } return true; } diff --git a/examples/llava/README.md b/examples/llava/README.md index d4810d42e..4fb0cf381 100644 --- a/examples/llava/README.md +++ b/examples/llava/README.md @@ -56,7 +56,7 @@ python ./examples/llava/convert-image-encoder-to-gguf.py -m ../clip-vit-large-pa python ./convert.py ../llava-v1.5-7b --skip-unknown ``` -Now both the LLaMA part and the image encoder is in the `llava-v1.5-7b` directory. +Now both the LLaMA part and the image encoder are in the `llava-v1.5-7b` directory. ## LLaVA 1.6 gguf conversion 1) First clone a LLaVA 1.6 model: diff --git a/examples/main/README.md b/examples/main/README.md index e7a38743c..97e2ae4c2 100644 --- a/examples/main/README.md +++ b/examples/main/README.md @@ -143,7 +143,7 @@ The `--ctx-size` option allows you to set the size of the prompt context used by ### Extended Context Size -Some fine-tuned models have extended the context length by scaling RoPE. For example, if the original pre-trained model have a context length (max sequence length) of 4096 (4k) and the fine-tuned model have 32k. That is a scaling factor of 8, and should work by setting the above `--ctx-size` to 32768 (32k) and `--rope-scale` to 8. +Some fine-tuned models have extended the context length by scaling RoPE. For example, if the original pre-trained model has a context length (max sequence length) of 4096 (4k) and the fine-tuned model has 32k. That is a scaling factor of 8, and should work by setting the above `--ctx-size` to 32768 (32k) and `--rope-scale` to 8. - `--rope-scale N`: Where N is the linear scaling factor used by the fine-tuned model. @@ -286,7 +286,7 @@ These options help improve the performance and memory usage of the LLaMA models. - `--numa distribute`: Pin an equal proportion of the threads to the cores on each NUMA node. This will spread the load amongst all cores on the system, utilitizing all memory channels at the expense of potentially requiring memory to travel over the slow links between nodes. - `--numa isolate`: Pin all threads to the NUMA node that the program starts on. This limits the number of cores and amount of memory that can be used, but guarantees all memory access remains local to the NUMA node. -- `--numa numactl`: Pin threads to the CPUMAP that is passed to the program by starting it with the numactl utility. This is the most flexible mode, and allow arbitraty core usage patterns, for example a map that uses all the cores on one NUMA nodes, and just enough cores on a second node to saturate the inter-node memory bus. +- `--numa numactl`: Pin threads to the CPUMAP that is passed to the program by starting it with the numactl utility. This is the most flexible mode, and allow arbitrary core usage patterns, for example a map that uses all the cores on one NUMA nodes, and just enough cores on a second node to saturate the inter-node memory bus. These flags attempt optimizations that help on some systems with non-uniform memory access. This currently consists of one of the above strategies, and disabling prefetch and readahead for mmap. The latter causes mapped pages to be faulted in on first access instead of all at once, and in combination with pinning threads to NUMA nodes, more of the pages end up on the NUMA node where they are used. Note that if the model is already in the system page cache, for example because of a previous run without this option, this will have little effect unless you drop the page cache first. 
This can be done by rebooting the system or on Linux by writing '3' to '/proc/sys/vm/drop_caches' as root. diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 3862d731b..1e4c6caf7 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -797,7 +797,7 @@ int main(int argc, char ** argv) { // deal with end of generation tokens in interactive mode if (llama_token_is_eog(model, llama_sampling_last(ctx_sampling))) { - LOG("found EOS token\n"); + LOG("found an EOG token\n"); if (params.interactive) { if (!params.antiprompt.empty()) { diff --git a/examples/server/README.md b/examples/server/README.md index b96a4444a..a7c3f0b5f 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -62,6 +62,18 @@ page cache before using this. See https://github.com/ggerganov/llama.cpp/issues/ - `--chat-template JINJA_TEMPLATE`: Set custom jinja chat template. This parameter accepts a string, not a file name. Default: template taken from model's metadata. We only support [some pre-defined templates](https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template) - `--log-disable`: Output logs to stdout only, not to `llama.log`. Default: enabled - `--log-format FORMAT`: Define the log output to FORMAT: json or text Default: `json` +- `--rope-scaling` : RoPE scaling method. Defaults to linear unless otherwise specified by the model. Options are `none`, `linear`, `yarn` +- `--rope-freq-base N` : RoPE frequency base (default: loaded from model) +- `--rope-freq-scale N`: RoPE frequency scaling factor, expands context by a factor of 1/N (e.g. 0.25) +- `--yarn-ext-factor N` : YaRN: extrapolation mix factor (Default: 1.0, 0.0 = full interpolation) +- `--yarn-attn-factor N` : YaRN: scale sqrt(t) or attention magnitude (default: 1.0) +- `--yarn-beta-slow N`: YaRN: High correction dim or alpha (default: 1.0) +- `--yarn-beta-fast N`: YaRN: low correction dim or beta (default: 32.0) +- `--pooling` : Pooling type for embeddings, use model default if unspecified. Options are `none`, `mean`, `cls` +- `-dt N`, `--defrag-thold N`: KV cache defragmentation threshold (default: -1.0, < 0 = disabled) +- `-fa`, `--flash-attn` : enable flash attention (default: disabled). +- `-ctk TYPE`, `--cache-type-k TYPE` : KV cache data type for K (default: `f16`, options `f32`, `f16`, `q8_0`, `q4_0`, `q4_1`, `iq4_nl`, `q5_0`, or `q5_1`) +- `-ctv TYPE`, `--cache-type-v TYPE` : KV cache type for V (default `f16`, see `-ctk` for options) **If compiled with `LLAMA_SERVER_SSL=ON`** - `--ssl-key-file FNAME`: path to file a PEM-encoded SSL private key @@ -260,7 +272,7 @@ node index.js `logit_bias`: Modify the likelihood of a token appearing in the generated text completion. For example, use `"logit_bias": [[15043,1.0]]` to increase the likelihood of the token 'Hello', or `"logit_bias": [[15043,-1.0]]` to decrease its likelihood. Setting the value to false, `"logit_bias": [[15043,false]]` ensures that the token `Hello` is never produced. The tokens can also be represented as strings, e.g. `[["Hello, World!",-0.5]]` will reduce the likelihood of all the individual tokens that represent the string `Hello, World!`, just like the `presence_penalty` does. Default: `[]` - `n_probs`: If greater than 0, the response also contains the probabilities of top N tokens for each generated token. Default: `0` + `n_probs`: If greater than 0, the response also contains the probabilities of top N tokens for each generated token given the sampling settings. 
Note that for temperature < 0 the tokens are sampled greedily but token probabilities are still being calculated via a simple softmax of the logits without considering any other sampler settings. Default: `0` `min_keep`: If greater than 0, force samplers to return N possible tokens at minimum. Default: `0` diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 97bf65a28..0c6e22e72 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -2267,17 +2267,31 @@ struct server_context { llama_token_data_array cur_p = { slot.ctx_sampling->cur.data(), slot.ctx_sampling->cur.size(), false }; result.tok = id; - const int32_t n_probs = slot.sparams.n_probs; - if (slot.sparams.temp <= 0 && n_probs > 0) { - // for llama_sample_token_greedy we need to sort candidates - llama_sample_softmax(ctx, &cur_p); - } + const size_t n_probs = std::min(cur_p.size, (size_t) slot.sparams.n_probs); + if (n_probs > 0) { + const size_t n_considered = slot.ctx_sampling->n_considered; - for (size_t i = 0; i < std::min(cur_p.size, (size_t) n_probs); ++i) { - result.probs.push_back({ - cur_p.data[i].id, - cur_p.data[i].p - }); + // Make sure at least n_probs top tokens are at the front of the vector: + if (slot.sparams.temp == 0.0f && n_probs > n_considered) { + llama_sample_top_k(ctx, &cur_p, n_probs, 0); + } + + if (slot.sparams.temp == 0.0f) { + // With greedy sampling the probabilities have possibly not been calculated. + for (size_t i = 0; i < n_probs; ++i) { + result.probs.push_back({ + cur_p.data[i].id, + i == 0 ? 1.0f : 0.0f + }); + } + } else { + for (size_t i = 0; i < n_probs; ++i) { + result.probs.push_back({ + cur_p.data[i].id, + i >= n_considered ? 0.0f : cur_p.data[i].p // Tokens filtered out due to e.g. top_k have 0 probability. + }); + } + } } if (!process_token(result, slot)) { diff --git a/examples/sycl/README.md b/examples/sycl/README.md index b46f17f39..c589c2d3a 100644 --- a/examples/sycl/README.md +++ b/examples/sycl/README.md @@ -1,6 +1,6 @@ # llama.cpp/example/sycl -This example program provide the tools for llama.cpp for SYCL on Intel GPU. +This example program provides the tools for llama.cpp for SYCL on Intel GPU. 
## Tool diff --git a/ggml-cuda.cu b/ggml-cuda.cu index ba65d4747..199b74402 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -115,7 +115,7 @@ static ggml_cuda_device_info ggml_cuda_init() { for (int id = 0; id < info.device_count; ++id) { int device_vmm = 0; -#if !defined(GGML_USE_HIPBLAS) +#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) CUdevice device; CU_CHECK(cuDeviceGet(&device, id)); CU_CHECK(cuDeviceGetAttribute(&device_vmm, CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED, device)); @@ -259,7 +259,7 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool { }; // pool with virtual memory -#if !defined(GGML_USE_HIPBLAS) +#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) struct ggml_cuda_pool_vmm : public ggml_cuda_pool { static const size_t CUDA_POOL_VMM_MAX_SIZE = 1ull << 35; // 32 GB @@ -356,7 +356,7 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool { #endif // !defined(GGML_USE_HIPBLAS) std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(int device) { -#if !defined(GGML_USE_HIPBLAS) +#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) if (ggml_cuda_info().devices[device].vmm) { return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_vmm(device)); } diff --git a/llama.cpp b/llama.cpp index 9df4f864d..cb4084dc5 100644 --- a/llama.cpp +++ b/llama.cpp @@ -4443,6 +4443,9 @@ static void llm_load_vocab( } else if ( tokenizer_pre == "command-r") { vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_COMMAND_R; + } else if ( + tokenizer_pre == "olmo") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_OLMO; } else { throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str())); } @@ -12539,6 +12542,7 @@ struct llm_tokenizer_bpe { }); break; case LLAMA_VOCAB_PRE_TYPE_GPT2: + case LLAMA_VOCAB_PRE_TYPE_OLMO: word_collection = unicode_regex_split(text, { "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", }); diff --git a/llama.h b/llama.h index 0fd75fc79..ffbd28c85 100644 --- a/llama.h +++ b/llama.h @@ -81,6 +81,7 @@ extern "C" { LLAMA_VOCAB_PRE_TYPE_GPT2 = 7, LLAMA_VOCAB_PRE_TYPE_REFACT = 8, LLAMA_VOCAB_PRE_TYPE_COMMAND_R = 9, + LLAMA_VOCAB_PRE_TYPE_OLMO = 10, }; // note: these values should be synchronized with ggml_rope
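For illustration, the imatrix.cpp hunks above move from storing raw sums of squared activations to a count-weighted form, so tensors (or experts) that are activated an uneven number of times no longer skew the statistics: each element keeps a running sum plus a per-element count, and on save the stored value is the per-element mean rescaled by ncall, so existing readers that divide by ncall still recover the mean. Below is a minimal, self-contained sketch of that bookkeeping; the names ActStats, accumulate and to_saved are hypothetical stand-ins for Stats, collect_imatrix and save_imatrix, not identifiers from the patch.

// Sketch only: mirrors the values/counts/ncall bookkeeping added to examples/imatrix/imatrix.cpp.
#include <cstddef>
#include <cstdio>
#include <vector>

struct ActStats {
    std::vector<float> values; // running sum of squared activations per element
    std::vector<int>   counts; // number of rows that contributed to each element
    int ncall = 0;             // number of times the tensor was evaluated
};

// Accumulate one row of activations: add x[j]^2 and bump the per-element count.
static void accumulate(ActStats & s, const std::vector<float> & row) {
    if (s.values.empty()) {
        s.values.assign(row.size(), 0.0f);
        s.counts.assign(row.size(), 0);
    }
    ++s.ncall;
    for (size_t j = 0; j < row.size(); ++j) {
        s.values[j] += row[j]*row[j];
        s.counts[j]++;
    }
}

// Produce the values as they would be written to disk: per-element mean times ncall.
static std::vector<float> to_saved(const ActStats & s) {
    std::vector<float> out(s.values.size());
    for (size_t i = 0; i < out.size(); ++i) {
        out[i] = (s.values[i] / static_cast<float>(s.counts[i])) * static_cast<float>(s.ncall);
    }
    return out;
}

int main() {
    ActStats s;
    accumulate(s, {1.0f, 2.0f});
    accumulate(s, {3.0f, 4.0f});
    const std::vector<float> saved = to_saved(s);
    printf("saved: %.1f %.1f (ncall = %d)\n", saved[0], saved[1], s.ncall); // saved: 10.0 20.0 (ncall = 2)
    return 0;
}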
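Similarly, the server.cpp hunk above changes how per-token probabilities are reported for n_probs: with temperature 0 no distribution was ever computed, so the selected token is reported with probability 1.0 and the rest with 0.0, while with a positive temperature any candidate beyond n_considered (i.e. filtered out by top-k, top-p and similar samplers) is reported as 0.0. The rough standalone sketch below illustrates only that mapping; Candidate and report_top_probs are hypothetical names, and the llama_sample_top_k re-sort the patch performs for the greedy case is represented only by the ordering assumption in the comment.

// Sketch only: approximates the probability-reporting logic added to examples/server/server.cpp.
#include <cstddef>
#include <cstdio>
#include <vector>

struct Candidate { int id; float p; }; // stand-in for llama_token_data (token id + probability)

// cur is assumed to already have the sampled/top tokens at the front; the patch
// ensures this for temp == 0 by calling llama_sample_top_k when n_probs > n_considered.
static std::vector<Candidate> report_top_probs(const std::vector<Candidate> & cur,
                                               size_t n_considered, size_t n_probs, float temp) {
    std::vector<Candidate> out;
    const size_t n = n_probs < cur.size() ? n_probs : cur.size();
    for (size_t i = 0; i < n; ++i) {
        float p;
        if (temp == 0.0f) {
            p = (i == 0) ? 1.0f : 0.0f;                // greedy: probabilities were never computed
        } else {
            p = (i >= n_considered) ? 0.0f : cur[i].p; // filtered-out candidates get 0 probability
        }
        out.push_back({ cur[i].id, p });
    }
    return out;
}

int main() {
    const std::vector<Candidate> cur = { {42, 0.7f}, {7, 0.3f}, {13, 0.0f} };
    for (const Candidate & c : report_top_probs(cur, /*n_considered=*/2, /*n_probs=*/3, /*temp=*/0.8f)) {
        printf("token %d -> %.2f\n", c.id, c.p); // 42 -> 0.70, 7 -> 0.30, 13 -> 0.00
    }
    return 0;
}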