mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2026-05-10 04:00:53 +00:00
Merge branch 'upstream' into concedo_experimental
# Conflicts: # README.md # docs/build-s390x.md # examples/llama.vim # ggml/src/ggml-cann/aclnn_ops.cpp # ggml/src/ggml-cann/common.h # scripts/compare-llama-bench.py # src/CMakeLists.txt # tests/test-backend-ops.cpp # tools/llama-bench/README.md # tools/llama-bench/llama-bench.cpp # tools/server/README.md
This commit is contained in:
commit
8b8396c30c
48 changed files with 1595 additions and 1411 deletions
43
.github/workflows/build-riscv-native.yml
vendored
43
.github/workflows/build-riscv-native.yml
vendored
|
|
@ -1,43 +0,0 @@
|
|||
name: Build on RISCV Linux Machine by Cloud-V
|
||||
on:
|
||||
workflow_dispatch:
|
||||
workflow_call:
|
||||
|
||||
jobs:
|
||||
bianbu-riscv64-native: # Bianbu 2.2
|
||||
runs-on: self-hosted
|
||||
|
||||
steps:
|
||||
- name: Install prerequisites
|
||||
run: |
|
||||
sudo apt-get update || true
|
||||
sudo apt-get install -y libatomic1
|
||||
- uses: actions/checkout@v4
|
||||
- name: Setup Riscv
|
||||
run: |
|
||||
sudo apt-get update || true
|
||||
sudo apt-get install -y --no-install-recommends \
|
||||
build-essential \
|
||||
gcc-14-riscv64-linux-gnu \
|
||||
g++-14-riscv64-linux-gnu \
|
||||
cmake
|
||||
|
||||
- name: Build
|
||||
run: |
|
||||
cmake -B build -DLLAMA_CURL=OFF \
|
||||
-DCMAKE_BUILD_TYPE=Release \
|
||||
-DGGML_OPENMP=OFF \
|
||||
-DLLAMA_BUILD_EXAMPLES=ON \
|
||||
-DLLAMA_BUILD_TOOLS=ON \
|
||||
-DLLAMA_BUILD_TESTS=OFF \
|
||||
-DCMAKE_SYSTEM_NAME=Linux \
|
||||
-DCMAKE_SYSTEM_PROCESSOR=riscv64 \
|
||||
-DCMAKE_C_COMPILER=riscv64-linux-gnu-gcc-14 \
|
||||
-DCMAKE_CXX_COMPILER=riscv64-linux-gnu-g++-14 \
|
||||
-DCMAKE_POSITION_INDEPENDENT_CODE=ON \
|
||||
-DCMAKE_FIND_ROOT_PATH=/usr/lib/riscv64-linux-gnu \
|
||||
-DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \
|
||||
-DCMAKE_FIND_ROOT_PATH_MODE_LIBRARY=ONLY \
|
||||
-DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH
|
||||
|
||||
cmake --build build --config Release -j $(nproc)
|
||||
2
Makefile
2
Makefile
|
|
@ -688,7 +688,7 @@ embeddings_default.o: otherarch/embeddings_adapter.cpp
|
|||
$(CXX) $(CXXFLAGS) -c $< -o $@
|
||||
|
||||
# idiotic "for easier compilation"
|
||||
GPTTYPE_ADAPTER = gpttype_adapter.cpp otherarch/llama_v2.cpp otherarch/llama_v3.cpp src/llama.cpp src/llama-impl.cpp src/llama-chat.cpp src/llama-mmap.cpp src/llama-context.cpp src/llama-adapter.cpp src/llama-arch.cpp src/llama-batch.cpp src/llama-vocab.cpp src/llama-grammar.cpp src/llama-sampling.cpp src/llama-kv-cache-unified.cpp src/llama-kv-cache-unified-iswa.cpp src/llama-memory-hybrid.cpp src/llama-memory-recurrent.cpp src/llama-model-loader.cpp src/llama-model.cpp src/llama-quant.cpp src/llama-hparams.cpp otherarch/gptj_v1.cpp otherarch/gptj_v2.cpp otherarch/gptj_v3.cpp otherarch/gpt2_v1.cpp otherarch/gpt2_v2.cpp otherarch/gpt2_v3.cpp otherarch/rwkv_v2.cpp otherarch/rwkv_v3.cpp otherarch/neox_v2.cpp otherarch/neox_v3.cpp otherarch/mpt_v3.cpp ggml/include/ggml.h ggml/include/ggml-cpu.h ggml/include/ggml-cuda.h include/llama.h otherarch/llama-util.h
|
||||
GPTTYPE_ADAPTER = gpttype_adapter.cpp otherarch/llama_v2.cpp otherarch/llama_v3.cpp src/llama.cpp src/llama-impl.cpp src/llama-chat.cpp src/llama-mmap.cpp src/llama-context.cpp src/llama-adapter.cpp src/llama-arch.cpp src/llama-batch.cpp src/llama-vocab.cpp src/llama-grammar.cpp src/llama-sampling.cpp src/llama-kv-cache.cpp src/llama-kv-cache-iswa.cpp src/llama-memory-hybrid.cpp src/llama-memory-recurrent.cpp src/llama-model-loader.cpp src/llama-model.cpp src/llama-quant.cpp src/llama-hparams.cpp otherarch/gptj_v1.cpp otherarch/gptj_v2.cpp otherarch/gptj_v3.cpp otherarch/gpt2_v1.cpp otherarch/gpt2_v2.cpp otherarch/gpt2_v3.cpp otherarch/rwkv_v2.cpp otherarch/rwkv_v3.cpp otherarch/neox_v2.cpp otherarch/neox_v3.cpp otherarch/mpt_v3.cpp ggml/include/ggml.h ggml/include/ggml-cpu.h ggml/include/ggml-cuda.h include/llama.h otherarch/llama-util.h
|
||||
gpttype_adapter_failsafe.o: $(GPTTYPE_ADAPTER)
|
||||
$(CXX) $(CXXFLAGS) $(FAILSAFE_FLAGS) -c $< -o $@
|
||||
gpttype_adapter.o: $(GPTTYPE_ADAPTER)
|
||||
|
|
|
|||
|
|
@ -1757,7 +1757,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
[](common_params & params) {
|
||||
params.warmup = false;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL}));
|
||||
).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL, LLAMA_EXAMPLE_PERPLEXITY}));
|
||||
add_opt(common_arg(
|
||||
{"--spm-infill"},
|
||||
string_format(
|
||||
|
|
@ -2256,9 +2256,11 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
).set_examples({LLAMA_EXAMPLE_PERPLEXITY}));
|
||||
add_opt(common_arg(
|
||||
{"-dt", "--defrag-thold"}, "N",
|
||||
string_format("KV cache defragmentation threshold (default: %.1f, < 0 - disabled)", (double)params.defrag_thold),
|
||||
string_format("KV cache defragmentation threshold (DEPRECATED)"),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.defrag_thold = std::stof(value);
|
||||
GGML_UNUSED(params);
|
||||
GGML_UNUSED(value);
|
||||
LOG_WRN("DEPRECATED: --defrag-thold is deprecated and no longer necessary to specify\n");
|
||||
}
|
||||
).set_env("LLAMA_ARG_DEFRAG_THOLD"));
|
||||
add_opt(common_arg(
|
||||
|
|
|
|||
|
|
@ -1361,6 +1361,26 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
|
|||
"<|end|>",
|
||||
};
|
||||
|
||||
if (!inputs.json_schema.is_null()) {
|
||||
data.grammar_lazy = false;
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
auto schema = inputs.json_schema;
|
||||
builder.resolve_refs(schema);
|
||||
|
||||
auto not_end = builder.add_rule("not-end",
|
||||
"[^<] | \"<\" [^|] | \"<|\" [^e] | \"<|e\" [^n] | \"<|en\" [^d] | \"<|end\" [^|] | \"<|end|\" [^>]");
|
||||
auto analysis = builder.add_rule("analysis",
|
||||
"\"<|channel|>analysis<|message|>\" ( " + not_end + " )* \"<|end|>\"");
|
||||
auto constraint = builder.add_rule("constraint", "\"<|constrain|>\"? [a-zA-Z0-9_-]+");
|
||||
auto final = builder.add_rule("final",
|
||||
"\"<|channel|>final\" ( \" \" " + constraint + " )? \"<|message|>\" " +
|
||||
builder.add_schema("response", schema)
|
||||
);
|
||||
|
||||
builder.add_rule("root", "( " + analysis + " \"<|start|>assistant\" )? " + final);
|
||||
});
|
||||
}
|
||||
|
||||
if (inputs.tools.is_array() && !inputs.tools.empty()) {
|
||||
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
|
|
@ -2121,7 +2141,7 @@ static common_chat_params common_chat_templates_apply_jinja(
|
|||
}
|
||||
|
||||
// GPT-OSS
|
||||
if (src.find("<|channel|>") != std::string::npos && params.json_schema.is_null()) {
|
||||
if (src.find("<|channel|>") != std::string::npos) {
|
||||
return common_chat_params_init_gpt_oss(tmpl, params);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1160,7 +1160,6 @@ struct llama_context_params common_context_params_to_llama(const common_params &
|
|||
cparams.yarn_orig_ctx = params.yarn_orig_ctx;
|
||||
cparams.pooling_type = params.pooling_type;
|
||||
cparams.attention_type = params.attention_type;
|
||||
cparams.defrag_thold = params.defrag_thold;
|
||||
cparams.cb_eval = params.cb_eval;
|
||||
cparams.cb_eval_user_data = params.cb_eval_user_data;
|
||||
cparams.offload_kqv = !params.no_kv_offload;
|
||||
|
|
|
|||
|
|
@ -284,7 +284,6 @@ struct common_params {
|
|||
float yarn_beta_fast = 32.0f; // YaRN low correction dim
|
||||
float yarn_beta_slow = 1.0f; // YaRN high correction dim
|
||||
int32_t yarn_orig_ctx = 0; // YaRN original context length
|
||||
float defrag_thold = 0.1f; // KV cache defragmentation threshold
|
||||
|
||||
// offload params
|
||||
std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ The motivation for having this is that the conversion process can often be an
|
|||
iterative process, where the original model is inspected, converted, updates
|
||||
made to llama.cpp, converted again, etc. Once the model has been converted it
|
||||
needs to be verified against the original model, and then optionally quantified,
|
||||
and is some cases perplexity checked of the quantized model. And finally the
|
||||
and in some cases perplexity checked of the quantized model. And finally the
|
||||
model/models need to the ggml-org on Hugging Face. This tool/example tries to
|
||||
help with this process.
|
||||
|
||||
|
|
@ -62,7 +62,7 @@ Command line arguments take precedence over environment variables when both are
|
|||
|
||||
In cases where the transformer implementation for the model has not been released
|
||||
yet it is possible to set the environment variable `UNRELEASED_MODEL_NAME` which
|
||||
will the cause the transformer implementation to be loaded explicitely and not
|
||||
will then cause the transformer implementation to be loaded explicitely and not
|
||||
use AutoModelForCausalLM:
|
||||
```
|
||||
export UNRELEASED_MODEL_NAME=SomeNewModel
|
||||
|
|
@ -87,7 +87,7 @@ from the converted model.
|
|||
# Or using command line argument
|
||||
(venv) $ make causal-run-original-model MODEL_PATH=~/work/ai/models/some_model
|
||||
```
|
||||
This command will save two file to the `data` directory, one is a binary file
|
||||
This command will save two files to the `data` directory, one is a binary file
|
||||
containing logits which will be used for comparison with the converted model
|
||||
later, and the other is a text file which allows for manual visual inspection.
|
||||
|
||||
|
|
@ -128,11 +128,11 @@ Quantized model saved to: /path/to/quantized/model-Q8_0.gguf
|
|||
Export the quantized model path to QUANTIZED_MODEL variable in your environment
|
||||
```
|
||||
This will show the path to the quantized model in the terminal, which can then
|
||||
be used set the `QUANTIZED_MODEL` environment variable:
|
||||
be used to set the `QUANTIZED_MODEL` environment variable:
|
||||
```console
|
||||
export QUANTIZED_MODEL=/path/to/quantized/model-Q8_0.gguf
|
||||
```
|
||||
The the quantized model can be run using the following command:
|
||||
Then the quantized model can be run using the following command:
|
||||
```console
|
||||
(venv) $ make causal-run-quantized-model
|
||||
```
|
||||
|
|
@ -229,11 +229,11 @@ Quantized model saved to: /path/to/quantized/model-Q8_0.gguf
|
|||
Export the quantized model path to QUANTIZED_EMBEDDING_MODEL variable in your environment
|
||||
```
|
||||
This will show the path to the quantized model in the terminal, which can then
|
||||
be used set the `QUANTIZED_EMBEDDING_MODEL` environment variable:
|
||||
be used to set the `QUANTIZED_EMBEDDING_MODEL` environment variable:
|
||||
```console
|
||||
export QUANTIZED_EMBEDDING_MODEL=/path/to/quantized/model-Q8_0.gguf
|
||||
```
|
||||
The the quantized model can be run using the following command:
|
||||
Then the quantized model can be run using the following command:
|
||||
```console
|
||||
(venv) $ make embedding-run-quantized-model
|
||||
```
|
||||
|
|
@ -246,7 +246,7 @@ token/logits file:
|
|||
```console
|
||||
(venv) $ make perplexity-run QUANTIZED_MODEL=~/path/to/quantized/model.gguf
|
||||
```
|
||||
This will use the wikitext dataset to run the perplexity evaluation and and
|
||||
This will use the wikitext dataset to run the perplexity evaluation and
|
||||
output the perplexity score to the terminal. This value can then be compared
|
||||
with the perplexity score of the unquantized model.
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
--extra-index-url https://download.pytorch.org/whl/cpu
|
||||
torch~=2.6.0
|
||||
torchvision~=0.21.0
|
||||
transformers~=4.55.0
|
||||
|
|
|
|||
|
|
@ -518,6 +518,7 @@ extern "C" {
|
|||
GGML_OP_IM2COL,
|
||||
GGML_OP_IM2COL_BACK,
|
||||
GGML_OP_CONV_2D,
|
||||
GGML_OP_CONV_3D,
|
||||
GGML_OP_CONV_2D_DW,
|
||||
GGML_OP_CONV_TRANSPOSE_2D,
|
||||
GGML_OP_POOL_1D,
|
||||
|
|
@ -1965,6 +1966,23 @@ extern "C" {
|
|||
int d0, // dilation dimension 0
|
||||
int d1); // dilation dimension 1
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_conv_3d(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a, // kernel [KW, KH, KD, IC * OC]
|
||||
struct ggml_tensor * b, // input [W, H, D, C * N]
|
||||
int s0, // stride
|
||||
int s1,
|
||||
int s2,
|
||||
int p0, // padding
|
||||
int p1,
|
||||
int p2,
|
||||
int d0, // dilation
|
||||
int d1,
|
||||
int d2,
|
||||
int n_channels,
|
||||
int n_batch,
|
||||
int n_channels_out);
|
||||
|
||||
enum ggml_op_pool {
|
||||
GGML_OP_POOL_MAX,
|
||||
GGML_OP_POOL_AVG,
|
||||
|
|
|
|||
|
|
@ -1361,15 +1361,15 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
|
|||
std::vector<int32_t> ids;
|
||||
std::vector<ggml_bitset_t> used_ids;
|
||||
|
||||
for (int i = 0; i < sched->n_splits; i++) {
|
||||
struct ggml_backend_sched_split * split = &splits[i];
|
||||
for (int split_id = 0; split_id < sched->n_splits; split_id++) {
|
||||
struct ggml_backend_sched_split * split = &splits[split_id];
|
||||
int split_backend_id = split->backend_id;
|
||||
ggml_backend_t split_backend = sched->backends[split_backend_id];
|
||||
|
||||
// copy the input tensors to the split backend
|
||||
for (int j = 0; j < split->n_inputs; j++) {
|
||||
ggml_backend_t input_backend = ggml_backend_sched_get_tensor_backend(sched, split->inputs[j]);
|
||||
struct ggml_tensor * input = split->inputs[j];
|
||||
for (int input_id = 0; input_id < split->n_inputs; input_id++) {
|
||||
ggml_backend_t input_backend = ggml_backend_sched_get_tensor_backend(sched, split->inputs[input_id]);
|
||||
struct ggml_tensor * input = split->inputs[input_id];
|
||||
struct ggml_tensor * input_cpy = tensor_copy(input, split_backend_id, sched->cur_copy);
|
||||
|
||||
if (input->flags & GGML_TENSOR_FLAG_INPUT) {
|
||||
|
|
@ -1404,10 +1404,22 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
|
|||
|
||||
// get the ids
|
||||
ggml_tensor * ids_tensor = node->src[2];
|
||||
ggml_backend_t ids_backend = split_backend;
|
||||
|
||||
// if the ids tensor is also an input of the split, it may not have been copied yet to the split backend
|
||||
// in that case, we use the original ids tensor
|
||||
for (int i = input_id + 1; i < split->n_inputs; i++) {
|
||||
if (ids_tensor == tensor_copy(split->inputs[i], split_backend_id, sched->cur_copy)) {
|
||||
ids_tensor = split->inputs[i];
|
||||
ids_backend = ggml_backend_sched_get_tensor_backend(sched, split->inputs[i]);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (ids_tensor != prev_ids_tensor) {
|
||||
ids.resize(ggml_nbytes(ids_tensor) / sizeof(int32_t));
|
||||
ggml_backend_tensor_get_async(split_backend, ids_tensor, ids.data(), 0, ggml_nbytes(ids_tensor));
|
||||
ggml_backend_synchronize(split_backend);
|
||||
ggml_backend_tensor_get_async(ids_backend, ids_tensor, ids.data(), 0, ggml_nbytes(ids_tensor));
|
||||
ggml_backend_synchronize(ids_backend);
|
||||
|
||||
// find the used experts
|
||||
used_ids.clear();
|
||||
|
|
@ -1415,6 +1427,7 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
|
|||
for (int64_t i1 = 0; i1 < ids_tensor->ne[1]; i1++) {
|
||||
for (int64_t i0 = 0; i0 < ids_tensor->ne[0]; i0++) {
|
||||
int32_t id = ids[i1 * ids_tensor->nb[1]/sizeof(int32_t) + i0 * ids_tensor->nb[0]/sizeof(int32_t)];
|
||||
GGML_ASSERT(id >= 0 && id < n_expert);
|
||||
ggml_bitset_set(used_ids.data(), id);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -150,8 +150,6 @@
|
|||
#elif defined(__s390x__)
|
||||
// quants.c
|
||||
#define quantize_row_q8_K_generic quantize_row_q8_K
|
||||
#define ggml_vec_dot_q5_0_q8_0_generic ggml_vec_dot_q5_0_q8_0
|
||||
#define ggml_vec_dot_q5_1_q8_1_generic ggml_vec_dot_q5_1_q8_1
|
||||
#define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
|
||||
#define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
|
||||
#define ggml_vec_dot_q2_K_q8_K_generic ggml_vec_dot_q2_K_q8_K
|
||||
|
|
|
|||
|
|
@ -23,6 +23,27 @@
|
|||
|
||||
#define UNUSED GGML_UNUSED
|
||||
|
||||
#if defined(__VXE__) || defined(__VXE2__)
|
||||
#define B1(c,s,n) 0x ## n ## c , 0x ## n ## s
|
||||
#define B2(c,s,n) B1(c,s,n ## c), B1(c,s,n ## s)
|
||||
#define B3(c,s,n) B2(c,s,n ## c), B2(c,s,n ## s)
|
||||
#define B4(c,s,n) B3(c,s,n ## c), B3(c,s,n ## s)
|
||||
#define B5(c,s,n) B4(c,s,n ## c), B4(c,s,n ## s)
|
||||
#define B6(c,s,n) B5(c,s,n ## c), B5(c,s,n ## s)
|
||||
#define B7(c,s,n) B6(c,s,n ## c), B6(c,s,n ## s)
|
||||
#define B8(c,s ) B7(c,s, c), B7(c,s, s)
|
||||
|
||||
// precomputed tables for expanding 8bits to 8 bytes:
|
||||
static const __attribute__((aligned(16))) uint64_t table_b2b_0[1 << 8] = { B8(00, 10) }; // ( b ) << 4
|
||||
static const __attribute__((aligned(16))) uint64_t table_b2b_1[1 << 8] = { B8(10, 00) }; // (!b) << 4
|
||||
|
||||
// permute mask for byteswapping
|
||||
static const uint8x16_t v_kperm = (const uint8x16_t){
|
||||
7, 6, 5, 4, 3, 2, 1, 0,
|
||||
15, 14, 13, 12, 11, 10, 9, 8
|
||||
};
|
||||
#endif
|
||||
|
||||
void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
|
||||
assert(QK8_0 == 32);
|
||||
assert(k % QK8_0 == 0);
|
||||
|
|
@ -241,6 +262,301 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi
|
|||
#endif
|
||||
}
|
||||
|
||||
void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
const int qk = QK8_0;
|
||||
const int nb = n / qk;
|
||||
|
||||
assert(n % qk == 0);
|
||||
assert(qk == QK5_0);
|
||||
assert(nrc == 1);
|
||||
UNUSED(nrc);
|
||||
UNUSED(bx);
|
||||
UNUSED(by);
|
||||
UNUSED(bs);
|
||||
|
||||
const block_q5_0 * GGML_RESTRICT x = vx;
|
||||
const block_q8_0 * GGML_RESTRICT y = vy;
|
||||
|
||||
int ib = 0;
|
||||
float sumf = 0.0f;
|
||||
|
||||
#if defined(__VXE__) || defined(__VXE2__)
|
||||
float32x4_t v_sum0 = vec_splats(0.0f);
|
||||
float32x4_t v_sum1 = vec_splats(0.0f);
|
||||
|
||||
uint32_t qh0, qh1;
|
||||
uint64_t tmp0[4], tmp1[4];
|
||||
|
||||
const uint8x16_t v_m = vec_splats((uint8_t)0x0F);
|
||||
|
||||
#pragma GCC unroll 4
|
||||
for (; ib + 1 < nb; ib += 2) {
|
||||
const block_q5_0 * GGML_RESTRICT x0 = &x[ib + 0];
|
||||
const block_q5_0 * GGML_RESTRICT x1 = &x[ib + 1];
|
||||
const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0];
|
||||
const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1];
|
||||
|
||||
memcpy(&qh0, x0->qh, sizeof(qh0));
|
||||
memcpy(&qh1, x1->qh, sizeof(qh1));
|
||||
|
||||
tmp0[0] = table_b2b_1[(qh0 >> 0) & 0xFF];
|
||||
tmp0[1] = table_b2b_1[(qh0 >> 8) & 0xFF];
|
||||
tmp0[2] = table_b2b_1[(qh0 >> 16) & 0xFF];
|
||||
tmp0[3] = table_b2b_1[(qh0 >> 24) ];
|
||||
|
||||
tmp1[0] = table_b2b_1[(qh1 >> 0) & 0xFF];
|
||||
tmp1[1] = table_b2b_1[(qh1 >> 8) & 0xFF];
|
||||
tmp1[2] = table_b2b_1[(qh1 >> 16) & 0xFF];
|
||||
tmp1[3] = table_b2b_1[(qh1 >> 24) ];
|
||||
|
||||
int8x16_t v_qh0l = vec_xl(0, (const int8_t *)(tmp0 + 0));
|
||||
int8x16_t v_qh0h = vec_xl(0, (const int8_t *)(tmp0 + 2));
|
||||
int8x16_t v_qh1l = vec_xl(0, (const int8_t *)(tmp1 + 0));
|
||||
int8x16_t v_qh1h = vec_xl(0, (const int8_t *)(tmp1 + 2));
|
||||
|
||||
// required for fixing the byteorder
|
||||
v_qh0l = vec_perm(v_qh0l, v_qh0l, v_kperm);
|
||||
v_qh0h = vec_perm(v_qh0h, v_qh0h, v_kperm);
|
||||
v_qh1l = vec_perm(v_qh1l, v_qh1l, v_kperm);
|
||||
v_qh1h = vec_perm(v_qh1h, v_qh1h, v_kperm);
|
||||
|
||||
const uint8x16_t v_x0 = vec_xl(0, (const uint8_t *)x0->qs);
|
||||
const uint8x16_t v_x1 = vec_xl(0, (const uint8_t *)x1->qs);
|
||||
|
||||
int8x16_t v_x0l = (int8x16_t)vec_and(v_x0, v_m);
|
||||
int8x16_t v_x0h = (int8x16_t)vec_sr(v_x0, 4);
|
||||
int8x16_t v_x1l = (int8x16_t)vec_and(v_x1, v_m);
|
||||
int8x16_t v_x1h = (int8x16_t)vec_sr(v_x1, 4);
|
||||
|
||||
const int8x16_t v_x0lf = vec_sub(v_x0l, v_qh0l);
|
||||
const int8x16_t v_x0hf = vec_sub(v_x0h, v_qh0h);
|
||||
const int8x16_t v_x1lf = vec_sub(v_x1l, v_qh1l);
|
||||
const int8x16_t v_x1hf = vec_sub(v_x1h, v_qh1h);
|
||||
|
||||
const int8x16_t v_y0l = vec_xl(0, (const int8_t *)y0->qs);
|
||||
const int8x16_t v_y0h = vec_xl(QK8_0/2, (const int8_t *)y0->qs);
|
||||
const int8x16_t v_y1l = vec_xl(0, (const int8_t *)y1->qs);
|
||||
const int8x16_t v_y1h = vec_xl(QK8_0/2, (const int8_t *)y1->qs);
|
||||
|
||||
const int32x4_t v_xy0 = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_x0lf, v_y0l), v_x0hf, v_y0h);
|
||||
const int32x4_t v_xy1 = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_x1lf, v_y1l), v_x1hf, v_y1h);
|
||||
|
||||
const float32x4_t v_xy0f = vec_float(v_xy0);
|
||||
const float32x4_t v_xy1f = vec_float(v_xy1);
|
||||
|
||||
const float32x4_t v_d0 = vec_splats(GGML_CPU_FP16_TO_FP32(x0->d) * GGML_CPU_FP16_TO_FP32(y0->d));
|
||||
const float32x4_t v_d1 = vec_splats(GGML_CPU_FP16_TO_FP32(x1->d) * GGML_CPU_FP16_TO_FP32(y1->d));
|
||||
|
||||
v_sum0 = vec_madd(v_xy0f, v_d0, v_sum0);
|
||||
v_sum1 = vec_madd(v_xy1f, v_d1, v_sum1);
|
||||
}
|
||||
|
||||
sumf += vec_hsum(v_sum0) + vec_hsum(v_sum1);
|
||||
|
||||
#pragma GCC unroll 4
|
||||
for (; ib < nb; ++ib) {
|
||||
const block_q5_0 * GGML_RESTRICT x0 = &x[ib];
|
||||
const block_q8_0 * GGML_RESTRICT y0 = &y[ib];
|
||||
|
||||
uint32_t qh;
|
||||
memcpy(&qh, x0->qh, sizeof(qh));
|
||||
|
||||
uint64_t tmp[4];
|
||||
tmp[0] = table_b2b_1[(qh >> 0) & 0xFF];
|
||||
tmp[1] = table_b2b_1[(qh >> 8) & 0xFF];
|
||||
tmp[2] = table_b2b_1[(qh >> 16) & 0xFF];
|
||||
tmp[3] = table_b2b_1[(qh >> 24) ];
|
||||
|
||||
int8x16_t v_qhl = vec_xl(0, (const int8_t *)(tmp + 0));
|
||||
int8x16_t v_qhh = vec_xl(0, (const int8_t *)(tmp + 2));
|
||||
|
||||
// required for fixing the byteorder
|
||||
v_qhl = vec_perm(v_qhl, v_qhl, v_kperm);
|
||||
v_qhh = vec_perm(v_qhh, v_qhh, v_kperm);
|
||||
|
||||
const uint8x16_t v_x = vec_xl(0, (const uint8_t *)x0->qs);
|
||||
int8x16_t v_xl = (int8x16_t)vec_and(v_x, v_m);
|
||||
int8x16_t v_xh = (int8x16_t)vec_sr(v_x, 4);
|
||||
|
||||
const int8x16_t v_xlf = vec_sub(v_xl, v_qhl);
|
||||
const int8x16_t v_xhf = vec_sub(v_xh, v_qhh);
|
||||
|
||||
const int8x16_t v_yl = vec_xl(0, (const int8_t *)y0->qs);
|
||||
const int8x16_t v_yh = vec_xl(QK8_0/2, (const int8_t *)y0->qs);
|
||||
|
||||
const int32x4_t v_xy = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_xlf, v_yl), v_xhf, v_yh);
|
||||
const float32x4_t v_xyf = vec_float(v_xy);
|
||||
|
||||
const float32x4_t v_d = vec_splats(GGML_CPU_FP16_TO_FP32(x0->d) * GGML_CPU_FP16_TO_FP32(y0->d));
|
||||
const float32x4_t v_acc = vec_madd(v_xyf, v_d, vec_splats(0.0f));
|
||||
|
||||
sumf += vec_hsum(v_acc);
|
||||
}
|
||||
|
||||
*s = sumf;
|
||||
#else
|
||||
UNUSED(nb);
|
||||
UNUSED(x);
|
||||
UNUSED(y);
|
||||
UNUSED(ib);
|
||||
UNUSED(sumf);
|
||||
ggml_vec_dot_q5_0_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
|
||||
#endif
|
||||
}
|
||||
|
||||
void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
const int qk = QK8_1;
|
||||
const int nb = n / qk;
|
||||
|
||||
assert(n % qk == 0);
|
||||
assert(qk == QK5_1);
|
||||
assert(nrc == 1);
|
||||
UNUSED(nrc);
|
||||
UNUSED(bx);
|
||||
UNUSED(by);
|
||||
UNUSED(bs);
|
||||
|
||||
const block_q5_1 * GGML_RESTRICT x = vx;
|
||||
const block_q8_1 * GGML_RESTRICT y = vy;
|
||||
|
||||
int ib = 0;
|
||||
float sumf = 0.0f;
|
||||
|
||||
#if defined(__VXE__) || defined(__VXE2__)
|
||||
float32x4_t v_sum0 = vec_splats(0.0f);
|
||||
float32x4_t v_sum1 = vec_splats(0.0f);
|
||||
|
||||
float summs0 = 0.0f;
|
||||
float summs1 = 0.0f;
|
||||
|
||||
uint32_t qh0;
|
||||
uint32_t qh1;
|
||||
|
||||
uint64_t tmp0[4];
|
||||
uint64_t tmp1[4];
|
||||
|
||||
const uint8x16_t v_m = vec_splats((uint8_t)0x0F);
|
||||
|
||||
#pragma GCC unroll 4
|
||||
for (; ib + 1 < nb; ib += 2) {
|
||||
const block_q5_1 * GGML_RESTRICT x0 = &x[ib + 0];
|
||||
const block_q5_1 * GGML_RESTRICT x1 = &x[ib + 1];
|
||||
const block_q8_1 * GGML_RESTRICT y0 = &y[ib + 0];
|
||||
const block_q8_1 * GGML_RESTRICT y1 = &y[ib + 1];
|
||||
|
||||
summs0 += GGML_CPU_FP16_TO_FP32(x0->m) * GGML_CPU_FP16_TO_FP32(y0->s);
|
||||
summs1 += GGML_CPU_FP16_TO_FP32(x1->m) * GGML_CPU_FP16_TO_FP32(y1->s);
|
||||
|
||||
memcpy(&qh0, x0->qh, sizeof(qh0));
|
||||
memcpy(&qh1, x1->qh, sizeof(qh1));
|
||||
|
||||
tmp0[0] = table_b2b_0[(qh0 >> 0) & 0xFF];
|
||||
tmp0[1] = table_b2b_0[(qh0 >> 8) & 0xFF];
|
||||
tmp0[2] = table_b2b_0[(qh0 >> 16) & 0xFF];
|
||||
tmp0[3] = table_b2b_0[(qh0 >> 24) ];
|
||||
|
||||
tmp1[0] = table_b2b_0[(qh1 >> 0) & 0xFF];
|
||||
tmp1[1] = table_b2b_0[(qh1 >> 8) & 0xFF];
|
||||
tmp1[2] = table_b2b_0[(qh1 >> 16) & 0xFF];
|
||||
tmp1[3] = table_b2b_0[(qh1 >> 24) ];
|
||||
|
||||
int8x16_t v_qh0l = vec_xl(0, (const int8_t *)(tmp0 + 0));
|
||||
int8x16_t v_qh0h = vec_xl(0, (const int8_t *)(tmp0 + 2));
|
||||
int8x16_t v_qh1l = vec_xl(0, (const int8_t *)(tmp1 + 0));
|
||||
int8x16_t v_qh1h = vec_xl(0, (const int8_t *)(tmp1 + 2));
|
||||
|
||||
// required for fixing the byteorder
|
||||
v_qh0l = vec_perm(v_qh0l, v_qh0l, v_kperm);
|
||||
v_qh0h = vec_perm(v_qh0h, v_qh0h, v_kperm);
|
||||
v_qh1l = vec_perm(v_qh1l, v_qh1l, v_kperm);
|
||||
v_qh1h = vec_perm(v_qh1h, v_qh1h, v_kperm);
|
||||
|
||||
const uint8x16_t v_x0 = vec_xl(0, x0->qs);
|
||||
const uint8x16_t v_x1 = vec_xl(0, x1->qs);
|
||||
|
||||
const int8x16_t v_x0l = (int8x16_t)vec_and(v_x0, v_m);
|
||||
const int8x16_t v_x0h = (int8x16_t)vec_sr(v_x0, 4);
|
||||
const int8x16_t v_x1l = (int8x16_t)vec_and(v_x1, v_m);
|
||||
const int8x16_t v_x1h = (int8x16_t)vec_sr(v_x1, 4);
|
||||
|
||||
const int8x16_t v_x0lf = vec_or(v_x0l, v_qh0l);
|
||||
const int8x16_t v_x0hf = vec_or(v_x0h, v_qh0h);
|
||||
const int8x16_t v_x1lf = vec_or(v_x1l, v_qh1l);
|
||||
const int8x16_t v_x1hf = vec_or(v_x1h, v_qh1h);
|
||||
|
||||
const int8x16_t v_y0l = vec_xl(0 , y0->qs);
|
||||
const int8x16_t v_y0h = vec_xl(QK8_1/2, y0->qs);
|
||||
const int8x16_t v_y1l = vec_xl(0 , y1->qs);
|
||||
const int8x16_t v_y1h = vec_xl(QK8_1/2, y1->qs);
|
||||
|
||||
const int32x4_t v_xy0 = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_x0lf, v_y0l), v_x0hf, v_y0h);
|
||||
const int32x4_t v_xy1 = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_x1lf, v_y1l), v_x1hf, v_y1h);
|
||||
|
||||
const float32x4_t v_xy0f = vec_float(v_xy0);
|
||||
const float32x4_t v_xy1f = vec_float(v_xy1);
|
||||
|
||||
const float32x4_t v_d0 = vec_splats(GGML_CPU_FP16_TO_FP32(x0->d) * GGML_CPU_FP16_TO_FP32(y0->d));
|
||||
const float32x4_t v_d1 = vec_splats(GGML_CPU_FP16_TO_FP32(x1->d) * GGML_CPU_FP16_TO_FP32(y1->d));
|
||||
|
||||
v_sum0 = vec_madd(v_xy0f, v_d0, v_sum0);
|
||||
v_sum1 = vec_madd(v_xy1f, v_d1, v_sum1);
|
||||
}
|
||||
|
||||
sumf += vec_hsum(v_sum0) + vec_hsum(v_sum1) + summs0 + summs1;
|
||||
|
||||
#pragma GCC unroll 4
|
||||
for (; ib < nb; ++ib) {
|
||||
const block_q5_1 * GGML_RESTRICT x0 = &x[ib];
|
||||
const block_q8_1 * GGML_RESTRICT y0 = &y[ib];
|
||||
|
||||
float summs = GGML_CPU_FP16_TO_FP32(x0->m) * GGML_CPU_FP16_TO_FP32(y0->s);
|
||||
|
||||
uint32_t qh;
|
||||
memcpy(&qh, x0->qh, sizeof(qh));
|
||||
|
||||
uint64_t tmp[4];
|
||||
tmp[0] = table_b2b_0[(qh >> 0) & 0xFF];
|
||||
tmp[1] = table_b2b_0[(qh >> 8) & 0xFF];
|
||||
tmp[2] = table_b2b_0[(qh >> 16) & 0xFF];
|
||||
tmp[3] = table_b2b_0[(qh >> 24) ];
|
||||
|
||||
int8x16_t v_qhl = vec_xl(0, (const int8_t *)(tmp + 0));
|
||||
int8x16_t v_qhh = vec_xl(0, (const int8_t *)(tmp + 2));
|
||||
|
||||
// required for fixing the byteorder
|
||||
v_qhl = vec_perm(v_qhl, v_qhl, v_kperm);
|
||||
v_qhh = vec_perm(v_qhh, v_qhh, v_kperm);
|
||||
|
||||
const uint8x16_t v_x = vec_xl(0, x0->qs);
|
||||
const int8x16_t v_xl = (int8x16_t)vec_and(v_x, v_m);
|
||||
const int8x16_t v_xh = (int8x16_t)vec_sr(v_x, 4);
|
||||
|
||||
const int8x16_t v_xlf = vec_or(v_xl, v_qhl);
|
||||
const int8x16_t v_xhf = vec_or(v_xh, v_qhh);
|
||||
|
||||
const int8x16_t v_yl = vec_xl(0 , y0->qs);
|
||||
const int8x16_t v_yh = vec_xl(QK8_1/2, y0->qs);
|
||||
|
||||
const int32x4_t v_xy = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_xlf, v_yl), v_xhf, v_yh);
|
||||
const float32x4_t v_xyf = vec_float(v_xy);
|
||||
|
||||
const float32x4_t v_d = vec_splats(GGML_CPU_FP16_TO_FP32(x0->d) * GGML_CPU_FP16_TO_FP32(y0->d));
|
||||
const float32x4_t v_acc = vec_madd(v_xyf, v_d, v_acc);
|
||||
|
||||
sumf += vec_hsum(v_acc) + summs;
|
||||
}
|
||||
|
||||
*s = sumf;
|
||||
#else
|
||||
UNUSED(nb);
|
||||
UNUSED(x);
|
||||
UNUSED(y);
|
||||
UNUSED(ib);
|
||||
UNUSED(sumf);
|
||||
ggml_vec_dot_q5_1_q8_1_generic(n, s, bs, vx, bx, vy, by, nrc);
|
||||
#endif
|
||||
}
|
||||
|
||||
void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
const int qk = QK8_0;
|
||||
const int nb = n / qk;
|
||||
|
|
|
|||
|
|
@ -486,6 +486,14 @@ inline static int16x8_t vec_padd_s16(int16x8_t a, int16x8_t b) {
|
|||
return v_abo + v_abe;
|
||||
}
|
||||
|
||||
/**
|
||||
* @see https://github.com/ggml-org/llama.cpp/pull/14037
|
||||
*/
|
||||
inline float vec_hsum(float32x4_t v) {
|
||||
float32x4_t v_temp = v + vec_reve(v);
|
||||
return v_temp[0] + v_temp[1];
|
||||
}
|
||||
|
||||
inline static int32x4_t ggml_vec_dot(int32x4_t acc, int8x16_t a, int8x16_t b) {
|
||||
const int16x8_t p = vec_mule(a, b) + vec_mulo(a, b);
|
||||
return acc + (vec_unpackh(p) + vec_unpackl(p));
|
||||
|
|
|
|||
|
|
@ -2664,6 +2664,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
|||
{
|
||||
ggml_compute_forward_conv_2d(params, tensor);
|
||||
} break;
|
||||
case GGML_OP_CONV_3D:
|
||||
{
|
||||
ggml_compute_forward_conv_3d(params, tensor);
|
||||
} break;
|
||||
case GGML_OP_CONV_2D_DW:
|
||||
{
|
||||
ggml_compute_forward_conv_2d_dw(params, tensor);
|
||||
|
|
@ -3077,6 +3081,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
|
|||
case GGML_OP_IM2COL:
|
||||
case GGML_OP_IM2COL_BACK:
|
||||
case GGML_OP_CONV_2D:
|
||||
case GGML_OP_CONV_3D:
|
||||
case GGML_OP_CONV_2D_DW:
|
||||
case GGML_OP_CONV_TRANSPOSE_1D:
|
||||
case GGML_OP_CONV_TRANSPOSE_2D:
|
||||
|
|
@ -3627,6 +3632,7 @@ struct ggml_cplan ggml_graph_plan(
|
|||
}
|
||||
} break;
|
||||
case GGML_OP_CONV_2D:
|
||||
case GGML_OP_CONV_3D:
|
||||
{
|
||||
cur = GGML_IM2COL_WORK_SIZE;
|
||||
} break;
|
||||
|
|
|
|||
|
|
@ -7207,6 +7207,148 @@ void ggml_compute_forward_conv_2d(
|
|||
ggml_compute_forward_conv_2d_impl(params, src0, src1, dst, src0->type);
|
||||
}
|
||||
|
||||
// ggml_compute_forward_conv_3d
|
||||
|
||||
static void ggml_compute_forward_conv_3d_impl(const ggml_compute_params * params,
|
||||
const ggml_tensor * kernel,
|
||||
const ggml_tensor * src,
|
||||
ggml_tensor * dst,
|
||||
ggml_type kernel_type) {
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous(kernel));
|
||||
GGML_ASSERT(kernel_type == GGML_TYPE_F16 || kernel_type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(kernel->type == kernel_type);
|
||||
|
||||
const ggml_type_traits * traits = ggml_get_type_traits(kernel_type);
|
||||
|
||||
const int32_t s0 = dst->op_params[0];
|
||||
const int32_t s1 = dst->op_params[1];
|
||||
const int32_t s2 = dst->op_params[2];
|
||||
const int32_t p0 = dst->op_params[3];
|
||||
const int32_t p1 = dst->op_params[4];
|
||||
const int32_t p2 = dst->op_params[5];
|
||||
const int32_t d0 = dst->op_params[6];
|
||||
const int32_t d1 = dst->op_params[7];
|
||||
const int32_t d2 = dst->op_params[8];
|
||||
const int32_t c = dst->op_params[9];
|
||||
const int32_t n = dst->op_params[10];
|
||||
const int32_t oc = dst->op_params[11];
|
||||
|
||||
const int64_t src_w = src->ne[0];
|
||||
const int64_t src_h = src->ne[1];
|
||||
const int64_t src_d = src->ne[2];
|
||||
const int64_t knl_w = kernel->ne[0];
|
||||
const int64_t knl_h = kernel->ne[1];
|
||||
const int64_t knl_d = kernel->ne[2];
|
||||
const int64_t dst_w = dst->ne[0];
|
||||
const int64_t dst_h = dst->ne[1];
|
||||
const int64_t dst_d = dst->ne[2];
|
||||
|
||||
const float * src_data = (float *) src->data;
|
||||
void * knl_data = kernel->data;
|
||||
float * dst_data = (float *) dst->data;
|
||||
|
||||
const int64_t knl_n_per_channel = knl_w * knl_h * knl_d;
|
||||
const int64_t knl_n_total = knl_n_per_channel * c;
|
||||
const int64_t patch_total = n * dst_w * dst_h * dst_d;
|
||||
|
||||
const int64_t space_per_patch = knl_n_total * traits->type_size + oc * sizeof(float);
|
||||
const int64_t batch_size = params->wsize / space_per_patch;
|
||||
const int64_t patches_per_batch = batch_size > 8 ? (batch_size / 8) * 8 : batch_size;
|
||||
const int64_t batch_n = (patch_total + patches_per_batch - 1) / patches_per_batch;
|
||||
|
||||
GGML_ASSERT(patches_per_batch > 0 && batch_size >= 1);
|
||||
|
||||
void * tmp = params->wdata;
|
||||
|
||||
for (int64_t batch_i = 0; batch_i < batch_n; ++batch_i) {
|
||||
const int64_t patch_start_batch = batch_i * patches_per_batch;
|
||||
const int64_t patch_end_batch = std::min(patch_start_batch + patches_per_batch, patch_total);
|
||||
const int64_t patch_n_in_batch = patch_end_batch - patch_start_batch;
|
||||
|
||||
const int64_t patch_per_thread = (patch_n_in_batch + params->nth - 1) / params->nth;
|
||||
const int64_t patch_start = patch_start_batch + params->ith * patch_per_thread;
|
||||
const int64_t patch_end = std::min(patch_start + patch_per_thread, patch_end_batch);
|
||||
|
||||
for (int64_t p = patch_start; p < patch_end; ++p) {
|
||||
const int64_t p_in_batch = p % (dst_w * dst_h * dst_d);
|
||||
const int64_t p_in_depth = p_in_batch % (dst_w * dst_h);
|
||||
const int64_t batch_idx = p / (dst_w * dst_h * dst_d);
|
||||
const int64_t dst_z = p_in_batch / (dst_w * dst_h);
|
||||
const int64_t dst_y = p_in_depth / dst_w;
|
||||
const int64_t dst_x = p_in_depth % dst_w;
|
||||
|
||||
char * dst_row = (char *) tmp + (p % patches_per_batch) * knl_n_total * traits->type_size;
|
||||
|
||||
for (int64_t ic = 0; ic < c; ++ic) {
|
||||
for (int64_t kz = 0; kz < knl_d; ++kz) {
|
||||
for (int64_t ky = 0; ky < knl_h; ++ky) {
|
||||
for (int64_t kx = 0; kx < knl_w; ++kx) {
|
||||
const int64_t sz = dst_z * s2 + kz * d2 - p2;
|
||||
const int64_t sy = dst_y * s1 + ky * d1 - p1;
|
||||
const int64_t sx = dst_x * s0 + kx * d0 - p0;
|
||||
|
||||
int64_t dst_idx = ic * knl_n_per_channel + kz * (knl_h * knl_w) + ky * knl_w + kx;
|
||||
|
||||
float src_val;
|
||||
if (sz < 0 || sz >= src_d || sy < 0 || sy >= src_h || sx < 0 || sx >= src_w) {
|
||||
src_val = 0.0f;
|
||||
} else {
|
||||
const int64_t cn_idx = batch_idx * c + ic;
|
||||
const float * src_ptr = (const float *)((const char *)src_data + sx*src->nb[0] + sy*src->nb[1] + sz*src->nb[2] + cn_idx*src->nb[3]);
|
||||
src_val = *src_ptr;
|
||||
}
|
||||
|
||||
char * element_ptr = dst_row + dst_idx * traits->type_size;
|
||||
if (kernel_type == GGML_TYPE_F32) {
|
||||
*(float *)element_ptr = src_val;
|
||||
} else if (kernel_type == GGML_TYPE_F16) {
|
||||
*(ggml_fp16_t *)element_ptr = GGML_CPU_FP32_TO_FP16(src_val);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ggml_barrier(params->threadpool);
|
||||
|
||||
float * gemm_output = (float *) ((char *) tmp + patches_per_batch * knl_n_total * traits->type_size);
|
||||
ggml_call_mul_mat(kernel_type, params, patch_n_in_batch, oc, knl_n_total, tmp, knl_data, gemm_output);
|
||||
|
||||
ggml_barrier(params->threadpool);
|
||||
|
||||
const int64_t permute_per_thread = (patch_n_in_batch + params->nth - 1) / params->nth;
|
||||
const int64_t permute_start = params->ith * permute_per_thread;
|
||||
const int64_t permute_end = std::min(permute_start + permute_per_thread, patch_n_in_batch);
|
||||
|
||||
for (int64_t i = permute_start; i < permute_end; ++i) {
|
||||
const int64_t p = patch_start_batch + i;
|
||||
const int64_t p_in_batch = p % (dst_w * dst_h * dst_d);
|
||||
const int64_t p_in_depth = p_in_batch % (dst_w * dst_h);
|
||||
const int64_t batch_idx = p / (dst_w * dst_h * dst_d);
|
||||
const int64_t dst_z = p_in_batch / (dst_w * dst_h);
|
||||
const int64_t dst_y = p_in_depth / dst_w;
|
||||
const int64_t dst_x = p_in_depth % dst_w;
|
||||
|
||||
for (int64_t ioc = 0; ioc < oc; ++ioc) {
|
||||
const float value = gemm_output[i * oc + ioc];
|
||||
const int64_t ocn_idx = batch_idx * oc + ioc;
|
||||
float * dst_ptr = (float *)((char *)dst_data + dst_x*dst->nb[0] + dst_y*dst->nb[1] + dst_z*dst->nb[2] + ocn_idx*dst->nb[3]);
|
||||
*dst_ptr = value;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_compute_forward_conv_3d(
|
||||
const ggml_compute_params * params,
|
||||
ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
ggml_compute_forward_conv_3d_impl(params, src0, src1, dst, src0->type);
|
||||
}
|
||||
|
||||
// ggml_compute_forward_conv_transpose_2d
|
||||
|
||||
void ggml_compute_forward_conv_transpose_2d(
|
||||
|
|
|
|||
|
|
@ -70,6 +70,7 @@ void ggml_compute_forward_conv_transpose_1d(const struct ggml_compute_params * p
|
|||
void ggml_compute_forward_im2col(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_im2col_back_f32(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_conv_2d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_conv_3d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_conv_transpose_2d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_conv_2d_dw(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_pool_1d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
|
|
|
|||
|
|
@ -51,6 +51,7 @@ bool g_mul_mat_q = true;
|
|||
#include "ggml-cuda/wkv.cuh"
|
||||
#include "ggml-cuda/gla.cuh"
|
||||
#include "ggml-cuda/set-rows.cuh"
|
||||
#include "ggml-cuda/pad_reflect_1d.cuh"
|
||||
#include "ggml.h"
|
||||
|
||||
#include <algorithm>
|
||||
|
|
@ -2365,6 +2366,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
|||
case GGML_OP_PAD:
|
||||
ggml_cuda_op_pad(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_PAD_REFLECT_1D:
|
||||
ggml_cuda_op_pad_reflect_1d(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_ARANGE:
|
||||
ggml_cuda_op_arange(ctx, dst);
|
||||
break;
|
||||
|
|
@ -3503,6 +3507,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
|||
return ggml_is_contiguous(op->src[0]);
|
||||
case GGML_OP_UPSCALE:
|
||||
case GGML_OP_PAD:
|
||||
case GGML_OP_PAD_REFLECT_1D:
|
||||
case GGML_OP_ARANGE:
|
||||
case GGML_OP_TIMESTEP_EMBEDDING:
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
|
|
|
|||
82
ggml/src/ggml-cuda/pad_reflect_1d.cu
Normal file
82
ggml/src/ggml-cuda/pad_reflect_1d.cu
Normal file
|
|
@ -0,0 +1,82 @@
|
|||
#include "pad_reflect_1d.cuh"
|
||||
|
||||
static __global__ void pad_reflect_1d_kernel_f32(
|
||||
const void * __restrict__ src0,
|
||||
void * __restrict__ dst,
|
||||
const int64_t ne0,
|
||||
const int64_t ne00,
|
||||
const int64_t ne01,
|
||||
const int64_t ne02,
|
||||
const int64_t ne03,
|
||||
const int64_t nb00,
|
||||
const int64_t nb01,
|
||||
const int64_t nb02,
|
||||
const int64_t nb03,
|
||||
const int64_t nb0,
|
||||
const int64_t nb1,
|
||||
const int64_t nb2,
|
||||
const int64_t nb3,
|
||||
const int p0,
|
||||
const int p1) {
|
||||
|
||||
const int64_t i3 = blockIdx.z;
|
||||
const int64_t i2 = blockIdx.y;
|
||||
const int64_t i1 = blockIdx.x;
|
||||
|
||||
if (i1 >= ne01 || i2 >= ne02 || i3 >= ne03) {
|
||||
return;
|
||||
}
|
||||
|
||||
const char * src0_ptr = (const char *)src0 + i3*nb03 + i2*nb02 + i1*nb01;
|
||||
char * dst_ptr = (char *)dst + i3*nb3 + i2*nb2 + i1*nb1;
|
||||
|
||||
for (int64_t i0 = threadIdx.x; i0 < ne0; i0 += blockDim.x) {
|
||||
float value;
|
||||
|
||||
if (i0 < p0) {
|
||||
// Left padding - reflect
|
||||
value = *(const float *)(src0_ptr + (p0 - i0) * nb00);
|
||||
} else if (i0 < ne0 - p1) {
|
||||
// Middle - copy
|
||||
value = *(const float *)(src0_ptr + (i0 - p0) * nb00);
|
||||
} else {
|
||||
// Right padding - reflect
|
||||
int64_t src_idx = (ne0 - p1 - p0) - (p1 + 1 - (ne0 - i0)) - 1;
|
||||
value = *(const float *)(src0_ptr + src_idx * nb00);
|
||||
}
|
||||
|
||||
*(float *)(dst_ptr + i0 * nb0) = value;
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_cuda_op_pad_reflect_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
|
||||
const int32_t * opts = (const int32_t *) dst->op_params;
|
||||
const int p0 = opts[0];
|
||||
const int p1 = opts[1];
|
||||
|
||||
const int64_t ne00 = src0->ne[0];
|
||||
const int64_t ne01 = src0->ne[1];
|
||||
const int64_t ne02 = src0->ne[2];
|
||||
const int64_t ne03 = src0->ne[3];
|
||||
|
||||
const int64_t ne0 = dst->ne[0];
|
||||
|
||||
GGML_ASSERT(ne0 == ne00 + p0 + p1);
|
||||
|
||||
const dim3 block_dims(CUDA_PAD_REFLECT_1D_BLOCK_SIZE, 1, 1);
|
||||
const dim3 grid_dims(ne01, ne02, ne03);
|
||||
|
||||
pad_reflect_1d_kernel_f32<<<grid_dims, block_dims, 0, stream>>>(
|
||||
src0->data, dst->data,
|
||||
ne0, ne00, ne01, ne02, ne03,
|
||||
src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
|
||||
dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
|
||||
p0, p1
|
||||
);
|
||||
}
|
||||
5
ggml/src/ggml-cuda/pad_reflect_1d.cuh
Normal file
5
ggml/src/ggml-cuda/pad_reflect_1d.cuh
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
#include "common.cuh"
|
||||
|
||||
#define CUDA_PAD_REFLECT_1D_BLOCK_SIZE 256
|
||||
|
||||
void ggml_cuda_op_pad_reflect_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
|
@ -506,6 +506,7 @@ struct vk_device_struct {
|
|||
vk_pipeline pipeline_l2_norm_f32;
|
||||
|
||||
// [src/dst 0=fp32,1=fp16]
|
||||
vk_pipeline pipeline_exp[2];
|
||||
vk_pipeline pipeline_gelu[2];
|
||||
vk_pipeline pipeline_gelu_erf[2];
|
||||
vk_pipeline pipeline_gelu_quick[2];
|
||||
|
|
@ -545,8 +546,8 @@ struct vk_device_struct {
|
|||
vk_pipeline pipeline_opt_step_sgd_f32;
|
||||
vk_pipeline pipeline_conv2d_f32[CONV_SHAPE_COUNT];
|
||||
vk_pipeline pipeline_conv2d_f16_f32[CONV_SHAPE_COUNT];
|
||||
vk_pipeline pipeline_conv2d_dw_whcn_f32;
|
||||
vk_pipeline pipeline_conv2d_dw_cwhn_f32;
|
||||
vk_pipeline pipeline_conv2d_dw_whcn_f32, pipeline_conv2d_dw_whcn_f16_f32;
|
||||
vk_pipeline pipeline_conv2d_dw_cwhn_f32, pipeline_conv2d_dw_cwhn_f16_f32;
|
||||
|
||||
// [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}
|
||||
vk_pipeline pipeline_flash_attn_f32_f16_cm2[GGML_TYPE_COUNT][FA_HEAD_SIZE_COUNT][2][2][2];
|
||||
|
|
@ -1209,6 +1210,10 @@ struct ggml_backend_vk_context {
|
|||
vk::Fence fence, almost_ready_fence;
|
||||
bool almost_ready_fence_pending {};
|
||||
|
||||
// Cache most recent tensor that was converted into prealloc_y, and what pipeline it used to convert.
|
||||
vk_pipeline_struct * prealloc_y_last_pipeline_used {};
|
||||
const ggml_tensor * prealloc_y_last_tensor_used {};
|
||||
|
||||
vk_buffer buffer_pool[MAX_VK_BUFFERS];
|
||||
|
||||
vk_context_ref compute_ctx;
|
||||
|
|
@ -3078,6 +3083,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||
ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
CREATE_UNARY(exp)
|
||||
CREATE_UNARY(gelu)
|
||||
CREATE_UNARY(gelu_erf)
|
||||
CREATE_UNARY(gelu_quick)
|
||||
|
|
@ -3267,6 +3273,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f32, "conv2d_dw_cwhn_f32", conv2d_dw_cwhn_f32_len, conv2d_dw_cwhn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f16_f32, "conv2d_dw_whcn_f16_f32", conv2d_dw_whcn_f16_f32_len, conv2d_dw_whcn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f16_f32, "conv2d_dw_cwhn_f16_f32", conv2d_dw_cwhn_f16_f32_len, conv2d_dw_cwhn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
for (auto &c : compiles) {
|
||||
c.wait();
|
||||
|
|
@ -5681,10 +5689,20 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
|
|||
ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0, { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz * ne02 * ne03 }, vk_subbuffer{ d_X, 0, x_sz * ne02 * ne03 } }, pc, { (uint32_t)(x_ne * ne02 * ne03), 1, 1});
|
||||
}
|
||||
if (y_non_contig) {
|
||||
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
|
||||
if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() ||
|
||||
ctx->prealloc_y_last_tensor_used != src1) {
|
||||
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
|
||||
ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get();
|
||||
ctx->prealloc_y_last_tensor_used = src1;
|
||||
}
|
||||
}
|
||||
if (quantize_y) {
|
||||
ggml_vk_quantize_q8_1(ctx, subctx, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }, y_ne * ne12 * ne13);
|
||||
if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() ||
|
||||
ctx->prealloc_y_last_tensor_used != src1) {
|
||||
ggml_vk_quantize_q8_1(ctx, subctx, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }, y_ne * ne12 * ne13);
|
||||
ctx->prealloc_y_last_pipeline_used = to_q8_1.get();
|
||||
ctx->prealloc_y_last_tensor_used = src1;
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t stride_batch_x = ne00*ne01;
|
||||
|
|
@ -5859,7 +5877,12 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
|
|||
}
|
||||
if (y_non_contig) {
|
||||
GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne);
|
||||
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
|
||||
if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() ||
|
||||
ctx->prealloc_y_last_tensor_used != src1) {
|
||||
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
|
||||
ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get();
|
||||
ctx->prealloc_y_last_tensor_used = src1;
|
||||
}
|
||||
}
|
||||
|
||||
// For batch_n, the A matrix is the same for each batch, and B/D use the row stride as the batch stride
|
||||
|
|
@ -6289,7 +6312,12 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
|
|||
{ vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz * ne02 * ne03 }, vk_subbuffer{ d_X, 0, x_sz * ne02 * ne03 } }, pc, { (uint32_t)(x_ne * ne02 * ne03), 1, 1});
|
||||
}
|
||||
if (y_non_contig) {
|
||||
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
|
||||
if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() ||
|
||||
ctx->prealloc_y_last_tensor_used != src1) {
|
||||
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
|
||||
ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get();
|
||||
ctx->prealloc_y_last_tensor_used = src1;
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t stride_batch_x = ne00*ne01;
|
||||
|
|
@ -6477,7 +6505,12 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
|
|||
}
|
||||
if (y_non_contig) {
|
||||
GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne);
|
||||
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
|
||||
if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() ||
|
||||
ctx->prealloc_y_last_tensor_used != src1) {
|
||||
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
|
||||
ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get();
|
||||
ctx->prealloc_y_last_tensor_used = src1;
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t stride_batch_y = ne10*ne11;
|
||||
|
|
@ -6521,22 +6554,29 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
|
|||
GGML_ASSERT(nei0 <= 4096);
|
||||
const uint32_t split_size = std::min(nei1, 4096u / nei0);
|
||||
|
||||
ggml_tensor src1_copy = *src1;
|
||||
ggml_tensor src2_copy = *src2;
|
||||
ggml_tensor dst_copy = *dst;
|
||||
if (split_size == nei1) {
|
||||
ggml_vk_mul_mat_id_q_f16(ctx, subctx, src0, src1, src2, dst, dryrun);
|
||||
} else {
|
||||
ggml_tensor src1_copy = *src1;
|
||||
ggml_tensor src2_copy = *src2;
|
||||
ggml_tensor dst_copy = *dst;
|
||||
|
||||
for (uint32_t token_start = 0; token_start < nei1; token_start += split_size) {
|
||||
const uint32_t n_tokens = std::min(split_size, nei1 - token_start);
|
||||
for (uint32_t token_start = 0; token_start < nei1; token_start += split_size) {
|
||||
const uint32_t n_tokens = std::min(split_size, nei1 - token_start);
|
||||
|
||||
src1_copy.view_offs = src1->view_offs + token_start * src1_copy.nb[2];
|
||||
src2_copy.view_offs = src2->view_offs + token_start * src2_copy.nb[1];
|
||||
dst_copy.view_offs = dst->view_offs + token_start * dst_copy.nb[2];
|
||||
src1_copy.view_offs = src1->view_offs + token_start * src1_copy.nb[2];
|
||||
src2_copy.view_offs = src2->view_offs + token_start * src2_copy.nb[1];
|
||||
dst_copy.view_offs = dst->view_offs + token_start * dst_copy.nb[2];
|
||||
|
||||
src1_copy.ne[2] = n_tokens;
|
||||
src2_copy.ne[1] = n_tokens;
|
||||
dst_copy.ne[2] = n_tokens;
|
||||
src1_copy.ne[2] = n_tokens;
|
||||
src2_copy.ne[1] = n_tokens;
|
||||
dst_copy.ne[2] = n_tokens;
|
||||
|
||||
ggml_vk_mul_mat_id_q_f16(ctx, subctx, src0, &src1_copy, &src2_copy, &dst_copy, dryrun);
|
||||
ggml_vk_mul_mat_id_q_f16(ctx, subctx, src0, &src1_copy, &src2_copy, &dst_copy, dryrun);
|
||||
// invalidate cached prealloc_y, can't cache based on the copy of the ggml_tensor
|
||||
ctx->prealloc_y_last_pipeline_used = {};
|
||||
ctx->prealloc_y_last_tensor_used = nullptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -7127,6 +7167,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
|||
}
|
||||
|
||||
switch (ggml_get_unary_op(dst)) {
|
||||
case GGML_UNARY_OP_EXP:
|
||||
return ctx->device->pipeline_exp[dst->type == GGML_TYPE_F16];
|
||||
case GGML_UNARY_OP_SILU:
|
||||
return ctx->device->pipeline_silu[dst->type == GGML_TYPE_F16];
|
||||
case GGML_UNARY_OP_GELU:
|
||||
|
|
@ -7336,6 +7378,12 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
|||
} else if (ggml_is_contiguous_channels(src1)) {
|
||||
return ctx->device->pipeline_conv2d_dw_cwhn_f32;
|
||||
}
|
||||
} else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
|
||||
if (ggml_is_contiguous(src1)) {
|
||||
return ctx->device->pipeline_conv2d_dw_whcn_f16_f32;
|
||||
} else if (ggml_is_contiguous_channels(src1)) {
|
||||
return ctx->device->pipeline_conv2d_dw_cwhn_f16_f32;
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
default:
|
||||
|
|
@ -9732,6 +9780,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
|||
return false;
|
||||
case GGML_OP_UNARY:
|
||||
switch (ggml_get_unary_op(node)) {
|
||||
case GGML_UNARY_OP_EXP:
|
||||
case GGML_UNARY_OP_SILU:
|
||||
case GGML_UNARY_OP_GELU:
|
||||
case GGML_UNARY_OP_GELU_ERF:
|
||||
|
|
@ -10009,6 +10058,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
|||
break;
|
||||
case GGML_OP_UNARY:
|
||||
switch (ggml_get_unary_op(node)) {
|
||||
case GGML_UNARY_OP_EXP:
|
||||
case GGML_UNARY_OP_SILU:
|
||||
case GGML_UNARY_OP_GELU:
|
||||
case GGML_UNARY_OP_GELU_ERF:
|
||||
|
|
@ -10245,6 +10295,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
|
|||
break;
|
||||
case GGML_OP_UNARY:
|
||||
switch (ggml_get_unary_op(tensor)) {
|
||||
case GGML_UNARY_OP_EXP:
|
||||
case GGML_UNARY_OP_SILU:
|
||||
case GGML_UNARY_OP_GELU:
|
||||
case GGML_UNARY_OP_GELU_ERF:
|
||||
|
|
@ -10341,6 +10392,7 @@ static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) {
|
|||
ggml_vk_pool_free(ctx, buffer);
|
||||
}
|
||||
ctx->gc.temp_buffers.clear();
|
||||
ctx->prealloc_y_last_pipeline_used = {};
|
||||
|
||||
ggml_vk_command_pool_cleanup(ctx->device, ctx->compute_cmd_pool);
|
||||
ggml_vk_command_pool_cleanup(ctx->device, ctx->transfer_cmd_pool);
|
||||
|
|
@ -10376,6 +10428,7 @@ static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) {
|
|||
ggml_vk_destroy_buffer(ctx->prealloc_x);
|
||||
ggml_vk_destroy_buffer(ctx->prealloc_y);
|
||||
ggml_vk_destroy_buffer(ctx->prealloc_split_k);
|
||||
ctx->prealloc_y_last_pipeline_used = nullptr;
|
||||
|
||||
for (auto& buffer : ctx->buffer_pool) {
|
||||
ggml_vk_destroy_buffer(buffer);
|
||||
|
|
@ -10924,6 +10977,9 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
|||
compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->device->query_pool, 0);
|
||||
}
|
||||
|
||||
ctx->prealloc_y_last_pipeline_used = nullptr;
|
||||
ctx->prealloc_y_last_tensor_used = nullptr;
|
||||
|
||||
// Submit after enough work has accumulated, to overlap CPU cmdbuffer generation with GPU execution.
|
||||
// Estimate the amount of matmul work by looking at the weight matrix size, and submit every 100MB
|
||||
// (and scaled down based on model size, so smaller models submit earlier).
|
||||
|
|
@ -11155,6 +11211,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|||
switch (op->op) {
|
||||
case GGML_OP_UNARY:
|
||||
switch (ggml_get_unary_op(op)) {
|
||||
case GGML_UNARY_OP_EXP:
|
||||
case GGML_UNARY_OP_GELU:
|
||||
case GGML_UNARY_OP_GELU_ERF:
|
||||
case GGML_UNARY_OP_GELU_QUICK:
|
||||
|
|
@ -11954,6 +12011,9 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
|
|||
}
|
||||
} else if (tensor->op == GGML_OP_UNARY) {
|
||||
switch (ggml_get_unary_op(tensor)) {
|
||||
case GGML_UNARY_OP_EXP:
|
||||
tensor_clone = ggml_exp(ggml_ctx, src_clone[0]);
|
||||
break;
|
||||
case GGML_UNARY_OP_SILU:
|
||||
tensor_clone = ggml_silu(ggml_ctx, src_clone[0]);
|
||||
break;
|
||||
|
|
|
|||
20
ggml/src/ggml-vulkan/vulkan-shaders/exp.comp
Normal file
20
ggml/src/ggml-vulkan/vulkan-shaders/exp.comp
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
#version 450
|
||||
|
||||
#include "generic_head.comp"
|
||||
#include "types.comp"
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : enable
|
||||
|
||||
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
|
||||
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
|
||||
|
||||
void main() {
|
||||
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
|
||||
|
||||
if (i >= p.KX) {
|
||||
return;
|
||||
}
|
||||
data_d[i] = D_TYPE(exp(float(data_a[i])));
|
||||
}
|
||||
|
|
@ -600,6 +600,8 @@ void process_shaders() {
|
|||
|
||||
string_to_spv("upscale_f32", "upscale.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
|
||||
string_to_spv("exp_f16", "exp.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("exp_f32", "exp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("gelu_f16", "gelu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("gelu_erf_f16", "gelu_erf.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
|
|
@ -692,6 +694,8 @@ void process_shaders() {
|
|||
|
||||
string_to_spv("conv2d_dw_whcn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}}));
|
||||
string_to_spv("conv2d_dw_cwhn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}}));
|
||||
string_to_spv("conv2d_dw_whcn_f16_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}}));
|
||||
string_to_spv("conv2d_dw_cwhn_f16_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}}));
|
||||
|
||||
string_to_spv("roll_f32", "roll.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
|
||||
|
|
|
|||
|
|
@ -991,6 +991,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
|||
"IM2COL",
|
||||
"IM2COL_BACK",
|
||||
"CONV_2D",
|
||||
"CONV_3D",
|
||||
"CONV_2D_DW",
|
||||
"CONV_TRANSPOSE_2D",
|
||||
"POOL_1D",
|
||||
|
|
@ -1033,7 +1034,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
|||
"GLU",
|
||||
};
|
||||
|
||||
static_assert(GGML_OP_COUNT == 88, "GGML_OP_COUNT != 88");
|
||||
static_assert(GGML_OP_COUNT == 89, "GGML_OP_COUNT != 89");
|
||||
|
||||
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||
"none",
|
||||
|
|
@ -1093,6 +1094,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
|||
"im2col(x)",
|
||||
"im2col_back(x)",
|
||||
"conv_2d(x)",
|
||||
"conv_3d(x)",
|
||||
"conv_2d_dw(x)",
|
||||
"conv_transpose_2d(x)",
|
||||
"pool_1d(x)",
|
||||
|
|
@ -1135,7 +1137,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
|||
"glu(x)",
|
||||
};
|
||||
|
||||
static_assert(GGML_OP_COUNT == 88, "GGML_OP_COUNT != 88");
|
||||
static_assert(GGML_OP_COUNT == 89, "GGML_OP_COUNT != 89");
|
||||
|
||||
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
|
||||
|
||||
|
|
@ -4496,6 +4498,56 @@ struct ggml_tensor * ggml_conv_2d_direct(
|
|||
return result;
|
||||
}
|
||||
|
||||
// ggml_conv_3d
|
||||
|
||||
struct ggml_tensor * ggml_conv_3d(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
int s0,
|
||||
int s1,
|
||||
int s2,
|
||||
int p0,
|
||||
int p1,
|
||||
int p2,
|
||||
int d0,
|
||||
int d1,
|
||||
int d2,
|
||||
int c,
|
||||
int n,
|
||||
int oc) {
|
||||
|
||||
GGML_ASSERT(a->ne[3] == (int64_t) c * oc);
|
||||
GGML_ASSERT(b->ne[3] == (int64_t) c * n);
|
||||
|
||||
int64_t ne[4];
|
||||
ne[0] = ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0);
|
||||
ne[1] = ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1);
|
||||
ne[2] = ggml_calc_conv_output_size(b->ne[2], a->ne[2], s2, p2, d2);
|
||||
ne[3] = (int64_t) oc * n;
|
||||
|
||||
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
|
||||
|
||||
ggml_set_op_params_i32(result, 0, s0);
|
||||
ggml_set_op_params_i32(result, 1, s1);
|
||||
ggml_set_op_params_i32(result, 2, s2);
|
||||
ggml_set_op_params_i32(result, 3, p0);
|
||||
ggml_set_op_params_i32(result, 4, p1);
|
||||
ggml_set_op_params_i32(result, 5, p2);
|
||||
ggml_set_op_params_i32(result, 6, d0);
|
||||
ggml_set_op_params_i32(result, 7, d1);
|
||||
ggml_set_op_params_i32(result, 8, d2);
|
||||
ggml_set_op_params_i32(result, 9, c);
|
||||
ggml_set_op_params_i32(result, 10, n);
|
||||
ggml_set_op_params_i32(result, 11, oc);
|
||||
|
||||
result->op = GGML_OP_CONV_3D;
|
||||
result->src[0] = a;
|
||||
result->src[1] = b;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// ggml_conv_transpose_2d_p0
|
||||
|
||||
static int64_t ggml_calc_conv_transpose_output_size(int64_t ins, int64_t ks, int s, int p) {
|
||||
|
|
|
|||
|
|
@ -2590,6 +2590,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
|||
MODEL_TENSOR.ATTN_K,
|
||||
MODEL_TENSOR.ATTN_V,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
],
|
||||
MODEL_ARCH.SMALLTHINKER: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
|
|
|
|||
111
include/llama.h
111
include/llama.h
|
|
@ -67,8 +67,6 @@ extern "C" {
|
|||
|
||||
typedef struct llama_memory_i * llama_memory_t;
|
||||
|
||||
struct llama_kv_cache; // DEPRECATED (use llama_memory instead)
|
||||
|
||||
typedef int32_t llama_pos;
|
||||
typedef int32_t llama_token;
|
||||
typedef int32_t llama_seq_id;
|
||||
|
|
@ -317,7 +315,7 @@ extern "C" {
|
|||
float yarn_beta_fast; // YaRN low correction dim
|
||||
float yarn_beta_slow; // YaRN high correction dim
|
||||
uint32_t yarn_orig_ctx; // YaRN original context size
|
||||
float defrag_thold; // defragment the KV cache if holes/size > thold, <= 0 disabled (default)
|
||||
float defrag_thold; // [DEPRECATED] defragment the KV cache if holes/size > thold, <= 0 disabled (default)
|
||||
|
||||
ggml_backend_sched_eval_callback cb_eval;
|
||||
void * cb_eval_user_data;
|
||||
|
|
@ -472,8 +470,6 @@ extern "C" {
|
|||
LLAMA_API llama_memory_t llama_get_memory (const struct llama_context * ctx);
|
||||
LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx); // TODO: rename to llama_get_pooling_type
|
||||
|
||||
DEPRECATED(LLAMA_API struct llama_kv_cache * llama_get_kv_self(struct llama_context * ctx), "use llama_get_memory instead");
|
||||
|
||||
LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model);
|
||||
LLAMA_API enum llama_rope_type llama_model_rope_type(const struct llama_model * model);
|
||||
|
||||
|
|
@ -670,111 +666,6 @@ extern "C" {
|
|||
// Check if the memory supports shifting
|
||||
LLAMA_API bool llama_memory_can_shift(llama_memory_t mem);
|
||||
|
||||
//
|
||||
// KV cache for self-attention (TODO: deprecate in favor of llama_memory)
|
||||
//
|
||||
|
||||
// Returns the number of tokens in the KV cache (slow, use only for debug)
|
||||
// If a KV cell has multiple sequences assigned to it, it will be counted multiple times
|
||||
DEPRECATED(LLAMA_API int32_t llama_kv_self_n_tokens(const struct llama_context * ctx),
|
||||
"Use llama_kv_self_seq_pos_max() and llama_kv_self_seq_pos_min() instead (https://github.com/ggml-org/llama.cpp/issues/13793)");
|
||||
|
||||
// Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
|
||||
DEPRECATED(LLAMA_API int32_t llama_kv_self_used_cells(const struct llama_context * ctx),
|
||||
"Use llama_kv_self_seq_pos_max() and llama_kv_self_seq_pos_min() instead (https://github.com/ggml-org/llama.cpp/issues/13793)");
|
||||
|
||||
// Clear the KV cache - both cell info is erased and KV data is zeroed
|
||||
DEPRECATED(LLAMA_API void llama_kv_self_clear(
|
||||
struct llama_context * ctx),
|
||||
"Use llama_memory_clear() instead");
|
||||
|
||||
// Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
|
||||
// Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails
|
||||
// seq_id < 0 : match any sequence
|
||||
// p0 < 0 : [0, p1]
|
||||
// p1 < 0 : [p0, inf)
|
||||
DEPRECATED(LLAMA_API bool llama_kv_self_seq_rm(
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1),
|
||||
"Use llama_memory_seq_rm() instead");
|
||||
|
||||
// Copy all tokens that belong to the specified sequence to another sequence
|
||||
// Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence
|
||||
// p0 < 0 : [0, p1]
|
||||
// p1 < 0 : [p0, inf)
|
||||
DEPRECATED(LLAMA_API void llama_kv_self_seq_cp(
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id_src,
|
||||
llama_seq_id seq_id_dst,
|
||||
llama_pos p0,
|
||||
llama_pos p1),
|
||||
"Use llama_memory_seq_cp() instead");
|
||||
|
||||
// Removes all tokens that do not belong to the specified sequence
|
||||
DEPRECATED(LLAMA_API void llama_kv_self_seq_keep(
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id),
|
||||
"Use llama_memory_seq_keep() instead");
|
||||
|
||||
// Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
|
||||
// If the KV cache is RoPEd, the KV data is updated accordingly:
|
||||
// - lazily on next llama_decode()
|
||||
// p0 < 0 : [0, p1]
|
||||
// p1 < 0 : [p0, inf)
|
||||
DEPRECATED(LLAMA_API void llama_kv_self_seq_add(
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1,
|
||||
llama_pos delta),
|
||||
"Use llama_memory_seq_add() instead");
|
||||
|
||||
// Integer division of the positions by factor of `d > 1`
|
||||
// If the KV cache is RoPEd, the KV data is updated accordingly:
|
||||
// - lazily on next llama_decode()
|
||||
// p0 < 0 : [0, p1]
|
||||
// p1 < 0 : [p0, inf)
|
||||
DEPRECATED(LLAMA_API void llama_kv_self_seq_div(
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1,
|
||||
int d),
|
||||
"Use llama_memory_seq_div() instead");
|
||||
|
||||
// Returns the smallest position present in the KV cache for the specified sequence
|
||||
// This is typically non-zero only for SWA caches
|
||||
// Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache
|
||||
// Return -1 if the sequence is empty
|
||||
DEPRECATED(LLAMA_API llama_pos llama_kv_self_seq_pos_min(
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id),
|
||||
"Use llama_memory_seq_pos_min() instead");
|
||||
|
||||
// Returns the largest position present in the KV cache for the specified sequence
|
||||
// Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache
|
||||
// Return -1 if the sequence is empty
|
||||
DEPRECATED(LLAMA_API llama_pos llama_kv_self_seq_pos_max(
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id),
|
||||
"Use llama_memory_seq_pos_max() instead");
|
||||
|
||||
// Defragment the KV cache
|
||||
// This will be applied:
|
||||
// - lazily on next llama_decode()
|
||||
DEPRECATED(LLAMA_API void llama_kv_self_defrag(struct llama_context * ctx),
|
||||
"simply remove this call, the context will automatically decide when to do a defragmentation based on 'defrag_thold'");
|
||||
|
||||
// Check if the context supports KV cache shifting
|
||||
DEPRECATED(LLAMA_API bool llama_kv_self_can_shift(const struct llama_context * ctx),
|
||||
"use llama_memory_can_shift() instead");
|
||||
|
||||
// Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
|
||||
DEPRECATED(LLAMA_API void llama_kv_self_update(struct llama_context * ctx),
|
||||
"simply remove this call, updates are applied lazily on the next llama_decode()");
|
||||
|
||||
//
|
||||
// State / sessions
|
||||
//
|
||||
|
|
|
|||
|
|
@ -2010,6 +2010,7 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
|||
{ LLM_TENSOR_SHORTCONV_OUTPROJ, "blk.%d.shortconv.out_proj" },
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
}
|
||||
},
|
||||
{
|
||||
|
|
|
|||
|
|
@ -39,7 +39,6 @@ llama_context::llama_context(
|
|||
cparams.yarn_attn_factor = params.yarn_attn_factor;
|
||||
cparams.yarn_beta_fast = params.yarn_beta_fast;
|
||||
cparams.yarn_beta_slow = params.yarn_beta_slow;
|
||||
cparams.defrag_thold = params.defrag_thold;
|
||||
cparams.embeddings = params.embeddings;
|
||||
cparams.offload_kqv = params.offload_kqv;
|
||||
cparams.flash_attn = params.flash_attn;
|
||||
|
|
@ -93,7 +92,7 @@ llama_context::llama_context(
|
|||
// the batch has to be at least GGML_KQ_MASK_PAD because we will be padding the KQ_mask
|
||||
// this is required by GPU kernels in order to avoid out-of-bounds accesses (e.g. ggml_flash_attn_ext)
|
||||
// ref: https://github.com/ggerganov/llama.cpp/pull/5021
|
||||
// TODO: this padding is not needed for the cache-less context so we should probably move it to llama_context_kv_self
|
||||
// TODO: this padding is not needed for the cache-less context so we should probably move it to llama_memory
|
||||
if (cparams.n_batch < GGML_KQ_MASK_PAD) {
|
||||
LLAMA_LOG_WARN("%s: n_batch is less than GGML_KQ_MASK_PAD - increasing to %d\n", __func__, GGML_KQ_MASK_PAD);
|
||||
cparams.n_batch = GGML_KQ_MASK_PAD;
|
||||
|
|
@ -439,26 +438,12 @@ llama_memory_t llama_context::get_memory() const {
|
|||
return memory.get();
|
||||
}
|
||||
|
||||
// deprecated
|
||||
void llama_context::kv_self_defrag_sched() {
|
||||
if (!memory) {
|
||||
return;
|
||||
}
|
||||
|
||||
memory_force_optimize = true;
|
||||
}
|
||||
|
||||
// deprecated
|
||||
bool llama_context::kv_self_update(bool optimize) {
|
||||
bool llama_context::memory_update(bool optimize) {
|
||||
if (!memory) {
|
||||
return false;
|
||||
}
|
||||
|
||||
{
|
||||
// TODO: remove in the future
|
||||
optimize |= memory_force_optimize;
|
||||
memory_force_optimize = false;
|
||||
|
||||
const auto mctx = memory->init_update(this, optimize);
|
||||
switch (mctx->get_status()) {
|
||||
case LLAMA_MEMORY_STATUS_SUCCESS:
|
||||
|
|
@ -992,8 +977,8 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
|||
|
||||
bool did_optimize = false;
|
||||
|
||||
// handle any pending defrags/shifts
|
||||
kv_self_update(false);
|
||||
// handle any pending shifts/copies
|
||||
memory_update(false);
|
||||
|
||||
llama_memory_context_ptr mctx;
|
||||
|
||||
|
|
@ -1018,7 +1003,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
|||
if (!did_optimize) {
|
||||
did_optimize = true;
|
||||
|
||||
if (kv_self_update(true)) {
|
||||
if (memory_update(true)) {
|
||||
LLAMA_LOG_DEBUG("%s: retrying batch size %d after cache optimization\n", __func__, balloc->get_n_tokens());
|
||||
|
||||
continue;
|
||||
|
|
@ -2338,16 +2323,6 @@ const llama_model * llama_get_model(const llama_context * ctx) {
|
|||
return &ctx->get_model();
|
||||
}
|
||||
|
||||
// deprecated
|
||||
llama_kv_cache * llama_get_kv_self(llama_context * ctx) {
|
||||
return dynamic_cast<llama_kv_cache *>(ctx->get_memory());
|
||||
}
|
||||
|
||||
// deprecated
|
||||
void llama_kv_self_update(llama_context * ctx) {
|
||||
ctx->kv_self_update(false);
|
||||
}
|
||||
|
||||
enum llama_pooling_type llama_pooling_type(const llama_context * ctx) {
|
||||
return ctx->pooling_type();
|
||||
}
|
||||
|
|
@ -2565,168 +2540,6 @@ bool llama_memory_can_shift(llama_memory_t mem) {
|
|||
return mem->get_can_shift();
|
||||
}
|
||||
|
||||
//
|
||||
// kv cache
|
||||
//
|
||||
|
||||
// deprecated
|
||||
int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
|
||||
const auto * kv = llama_get_memory(ctx);
|
||||
if (!kv) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
int32_t res = 0;
|
||||
|
||||
for (uint32_t s = 0; s < ctx->get_cparams().n_seq_max; s++) {
|
||||
const llama_pos p0 = kv->seq_pos_min(s);
|
||||
const llama_pos p1 = kv->seq_pos_max(s);
|
||||
|
||||
if (p0 >= 0) {
|
||||
res += (p1 - p0) + 1;
|
||||
}
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
// deprecated
|
||||
// note: this is the same as above - will be removed anyway, so it's ok
|
||||
int32_t llama_kv_self_used_cells(const llama_context * ctx) {
|
||||
const auto * kv = llama_get_memory(ctx);
|
||||
if (!kv) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
int32_t res = 0;
|
||||
|
||||
for (uint32_t s = 0; s < ctx->get_cparams().n_seq_max; s++) {
|
||||
const llama_pos p0 = kv->seq_pos_min(s);
|
||||
const llama_pos p1 = kv->seq_pos_max(s);
|
||||
|
||||
if (p0 >= 0) {
|
||||
res += (p1 - p0) + 1;
|
||||
}
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
// deprecated
|
||||
void llama_kv_self_clear(llama_context * ctx) {
|
||||
auto * kv = llama_get_memory(ctx);
|
||||
if (!kv) {
|
||||
return;
|
||||
}
|
||||
|
||||
llama_memory_clear(kv, true);
|
||||
}
|
||||
|
||||
// deprecated
|
||||
bool llama_kv_self_seq_rm(
|
||||
llama_context * ctx,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1) {
|
||||
auto * kv = llama_get_memory(ctx);
|
||||
if (!kv) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return llama_memory_seq_rm(kv, seq_id, p0, p1);
|
||||
}
|
||||
|
||||
// deprecated
|
||||
void llama_kv_self_seq_cp(
|
||||
llama_context * ctx,
|
||||
llama_seq_id seq_id_src,
|
||||
llama_seq_id seq_id_dst,
|
||||
llama_pos p0,
|
||||
llama_pos p1) {
|
||||
auto * kv = llama_get_memory(ctx);
|
||||
if (!kv) {
|
||||
return;
|
||||
}
|
||||
|
||||
llama_memory_seq_cp(kv, seq_id_src, seq_id_dst, p0, p1);
|
||||
}
|
||||
|
||||
// deprecated
|
||||
void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
|
||||
auto * kv = llama_get_memory(ctx);
|
||||
if (!kv) {
|
||||
return;
|
||||
}
|
||||
|
||||
llama_memory_seq_keep(kv, seq_id);
|
||||
}
|
||||
|
||||
// deprecated
|
||||
void llama_kv_self_seq_add(
|
||||
llama_context * ctx,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1,
|
||||
llama_pos delta) {
|
||||
auto * kv = llama_get_memory(ctx);
|
||||
if (!kv) {
|
||||
return;
|
||||
}
|
||||
|
||||
llama_memory_seq_add(kv, seq_id, p0, p1, delta);
|
||||
}
|
||||
|
||||
// deprecated
|
||||
void llama_kv_self_seq_div(
|
||||
llama_context * ctx,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1,
|
||||
int d) {
|
||||
auto * kv = llama_get_memory(ctx);
|
||||
if (!kv) {
|
||||
return;
|
||||
}
|
||||
|
||||
llama_memory_seq_div(kv, seq_id, p0, p1, d);
|
||||
}
|
||||
|
||||
// deprecated
|
||||
llama_pos llama_kv_self_seq_pos_min(llama_context * ctx, llama_seq_id seq_id) {
|
||||
auto * kv = llama_get_memory(ctx);
|
||||
if (!kv) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
return llama_memory_seq_pos_min(kv, seq_id);
|
||||
}
|
||||
|
||||
// deprecated
|
||||
llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
|
||||
auto * kv = llama_get_memory(ctx);
|
||||
if (!kv) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
return llama_memory_seq_pos_max(kv, seq_id);
|
||||
}
|
||||
|
||||
// deprecated
|
||||
void llama_kv_self_defrag(llama_context * ctx) {
|
||||
// force defrag
|
||||
ctx->kv_self_defrag_sched();
|
||||
}
|
||||
|
||||
// deprecated
|
||||
bool llama_kv_self_can_shift(const llama_context * ctx) {
|
||||
auto * kv = llama_get_memory(ctx);
|
||||
if (!kv) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return llama_memory_can_shift(kv);
|
||||
}
|
||||
|
||||
// llama state API
|
||||
|
||||
// deprecated
|
||||
|
|
|
|||
|
|
@ -46,10 +46,8 @@ struct llama_context {
|
|||
|
||||
llama_memory_t get_memory() const;
|
||||
|
||||
// return true of the KV cache was updated
|
||||
// TODO: remove
|
||||
bool kv_self_update(bool optimize);
|
||||
void kv_self_defrag_sched();
|
||||
// return true if the memory was updated
|
||||
bool memory_update(bool optimize);
|
||||
|
||||
enum llama_pooling_type pooling_type() const;
|
||||
|
||||
|
|
@ -230,9 +228,6 @@ private:
|
|||
|
||||
std::unique_ptr<llama_memory_i> memory;
|
||||
|
||||
// TODO: temporary, until the llama_kv_self_defrag() API is removed
|
||||
bool memory_force_optimize = false;
|
||||
|
||||
// decode output (2-dimensional array: [n_outputs][n_vocab])
|
||||
size_t logits_size = 0; // capacity (of floats) for logits
|
||||
float * logits = nullptr;
|
||||
|
|
|
|||
|
|
@ -24,7 +24,6 @@ struct llama_cparams {
|
|||
float yarn_attn_factor;
|
||||
float yarn_beta_fast;
|
||||
float yarn_beta_slow;
|
||||
float defrag_thold;
|
||||
|
||||
bool embeddings;
|
||||
bool causal_attn;
|
||||
|
|
|
|||
|
|
@ -4,8 +4,8 @@
|
|||
#include "llama-batch.h"
|
||||
#include "llama-cparams.h"
|
||||
|
||||
#include "llama-kv-cache-unified.h"
|
||||
#include "llama-kv-cache-unified-iswa.h"
|
||||
#include "llama-kv-cache.h"
|
||||
#include "llama-kv-cache-iswa.h"
|
||||
#include "llama-memory-hybrid.h"
|
||||
#include "llama-memory-recurrent.h"
|
||||
|
||||
|
|
@ -277,7 +277,7 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
|
|||
for (int s = 0; s < ubatch->n_seq_id[i0]; ++s) {
|
||||
const llama_seq_id s0 = ubatch->seq_id[i0][0];
|
||||
|
||||
// TODO: reimplement this like in llama_kv_cache_unified
|
||||
// TODO: reimplement this like in llama_kv_cache
|
||||
if (s0 == s1 && (!cparams.causal_attn || ubatch->pos[i0] <= ubatch->pos[i1])) {
|
||||
if (hparams.use_alibi) {
|
||||
f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]);
|
||||
|
|
@ -294,15 +294,15 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
|
|||
}
|
||||
}
|
||||
|
||||
void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
|
||||
void llm_graph_input_attn_kv::set_input(const llama_ubatch * ubatch) {
|
||||
mctx->set_input_k_idxs(self_k_idxs, ubatch);
|
||||
mctx->set_input_v_idxs(self_v_idxs, ubatch);
|
||||
|
||||
mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
|
||||
}
|
||||
|
||||
bool llm_graph_input_attn_kv_unified::can_reuse(const llm_graph_params & params) {
|
||||
const auto * mctx = static_cast<const llama_kv_cache_unified_context *>(params.mctx);
|
||||
bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) {
|
||||
const auto * mctx = static_cast<const llama_kv_cache_context *>(params.mctx);
|
||||
|
||||
this->mctx = mctx;
|
||||
|
||||
|
|
@ -319,7 +319,7 @@ bool llm_graph_input_attn_kv_unified::can_reuse(const llm_graph_params & params)
|
|||
return res;
|
||||
}
|
||||
|
||||
void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
|
||||
void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
|
||||
mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
|
||||
mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
|
||||
|
||||
|
|
@ -331,8 +331,8 @@ void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch
|
|||
mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
|
||||
}
|
||||
|
||||
bool llm_graph_input_attn_kv_unified_iswa::can_reuse(const llm_graph_params & params) {
|
||||
const auto * mctx = static_cast<const llama_kv_cache_unified_iswa_context *>(params.mctx);
|
||||
bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
|
||||
const auto * mctx = static_cast<const llama_kv_cache_iswa_context *>(params.mctx);
|
||||
|
||||
this->mctx = mctx;
|
||||
|
||||
|
|
@ -1186,7 +1186,7 @@ ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const {
|
|||
}
|
||||
|
||||
ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
|
||||
const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
|
||||
const auto * mctx_cur = static_cast<const llama_kv_cache_context *>(mctx);
|
||||
|
||||
auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, mctx_cur);
|
||||
|
||||
|
|
@ -1223,8 +1223,8 @@ ggml_tensor * llm_graph_context::build_attn_mha(
|
|||
ggml_tensor * v,
|
||||
ggml_tensor * kq_b,
|
||||
ggml_tensor * kq_mask,
|
||||
ggml_tensor * v_mla,
|
||||
ggml_tensor * sinks,
|
||||
ggml_tensor * v_mla,
|
||||
float kq_scale) const {
|
||||
const bool v_trans = v->nb[1] > v->nb[2];
|
||||
|
||||
|
|
@ -1360,6 +1360,7 @@ ggml_tensor * llm_graph_context::build_attn(
|
|||
ggml_tensor * k_cur,
|
||||
ggml_tensor * v_cur,
|
||||
ggml_tensor * kq_b,
|
||||
ggml_tensor * sinks,
|
||||
ggml_tensor * v_mla,
|
||||
float kq_scale,
|
||||
int il) const {
|
||||
|
|
@ -1381,7 +1382,7 @@ ggml_tensor * llm_graph_context::build_attn(
|
|||
ggml_tensor * k = k_cur;
|
||||
ggml_tensor * v = v_cur;
|
||||
|
||||
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, nullptr, kq_scale);
|
||||
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale);
|
||||
cb(cur, "kqv_out", il);
|
||||
|
||||
if (wo) {
|
||||
|
|
@ -1399,17 +1400,17 @@ ggml_tensor * llm_graph_context::build_attn(
|
|||
return cur;
|
||||
}
|
||||
|
||||
static std::unique_ptr<llm_graph_input_attn_kv_unified> build_attn_inp_kv_unified_impl(
|
||||
static std::unique_ptr<llm_graph_input_attn_kv> build_attn_inp_kv_impl(
|
||||
ggml_context * ctx0,
|
||||
const llama_ubatch & ubatch,
|
||||
const llama_hparams & hparams,
|
||||
const llama_cparams & cparams,
|
||||
const llama_kv_cache_unified_context * mctx_cur) {
|
||||
const llama_kv_cache_context * mctx_cur) {
|
||||
|
||||
auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, mctx_cur);
|
||||
auto inp = std::make_unique<llm_graph_input_attn_kv>(hparams, cparams, mctx_cur);
|
||||
|
||||
{
|
||||
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
|
||||
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_iswa for SWA");
|
||||
|
||||
const auto n_kv = mctx_cur->get_n_kv();
|
||||
const auto n_tokens = ubatch.n_tokens;
|
||||
|
|
@ -1427,22 +1428,23 @@ static std::unique_ptr<llm_graph_input_attn_kv_unified> build_attn_inp_kv_unifie
|
|||
return inp;
|
||||
}
|
||||
|
||||
llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
|
||||
const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
|
||||
llm_graph_input_attn_kv * llm_graph_context::build_attn_inp_kv() const {
|
||||
const auto * mctx_cur = static_cast<const llama_kv_cache_context *>(mctx);
|
||||
|
||||
auto inp = build_attn_inp_kv_unified_impl(ctx0, ubatch, hparams, cparams, mctx_cur);
|
||||
auto inp = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur);
|
||||
|
||||
return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp));
|
||||
return (llm_graph_input_attn_kv *) res->add_input(std::move(inp));
|
||||
}
|
||||
|
||||
ggml_tensor * llm_graph_context::build_attn(
|
||||
llm_graph_input_attn_kv_unified * inp,
|
||||
llm_graph_input_attn_kv * inp,
|
||||
ggml_tensor * wo,
|
||||
ggml_tensor * wo_b,
|
||||
ggml_tensor * q_cur,
|
||||
ggml_tensor * k_cur,
|
||||
ggml_tensor * v_cur,
|
||||
ggml_tensor * kq_b,
|
||||
ggml_tensor * sinks,
|
||||
ggml_tensor * v_mla,
|
||||
float kq_scale,
|
||||
int il) const {
|
||||
|
|
@ -1469,7 +1471,7 @@ ggml_tensor * llm_graph_context::build_attn(
|
|||
ggml_tensor * k = mctx_cur->get_k(ctx0, il);
|
||||
ggml_tensor * v = mctx_cur->get_v(ctx0, il);
|
||||
|
||||
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, nullptr, kq_scale);
|
||||
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale);
|
||||
cb(cur, "kqv_out", il);
|
||||
|
||||
if (wo) {
|
||||
|
|
@ -1488,40 +1490,15 @@ ggml_tensor * llm_graph_context::build_attn(
|
|||
}
|
||||
|
||||
ggml_tensor * llm_graph_context::build_attn(
|
||||
llm_graph_input_attn_kv_unified_iswa * inp,
|
||||
llm_graph_input_attn_kv_iswa * inp,
|
||||
ggml_tensor * wo,
|
||||
ggml_tensor * wo_b,
|
||||
ggml_tensor * q_cur,
|
||||
ggml_tensor * k_cur,
|
||||
ggml_tensor * v_cur,
|
||||
ggml_tensor * kq_b,
|
||||
ggml_tensor * v_mla,
|
||||
float kq_scale,
|
||||
int il) const {
|
||||
return build_attn_with_sinks(
|
||||
inp,
|
||||
wo,
|
||||
wo_b,
|
||||
q_cur,
|
||||
k_cur,
|
||||
v_cur,
|
||||
kq_b,
|
||||
v_mla,
|
||||
nullptr,
|
||||
kq_scale,
|
||||
il);
|
||||
}
|
||||
|
||||
ggml_tensor * llm_graph_context::build_attn_with_sinks(
|
||||
llm_graph_input_attn_kv_unified_iswa * inp,
|
||||
ggml_tensor * wo,
|
||||
ggml_tensor * wo_b,
|
||||
ggml_tensor * q_cur,
|
||||
ggml_tensor * k_cur,
|
||||
ggml_tensor * v_cur,
|
||||
ggml_tensor * kq_b,
|
||||
ggml_tensor * v_mla,
|
||||
ggml_tensor * sinks,
|
||||
ggml_tensor * v_mla,
|
||||
float kq_scale,
|
||||
int il) const {
|
||||
// these nodes are added to the graph together so that they are not reordered
|
||||
|
|
@ -1561,7 +1538,7 @@ ggml_tensor * llm_graph_context::build_attn_with_sinks(
|
|||
ggml_tensor * k = mctx_cur->get_k(ctx0, il);
|
||||
ggml_tensor * v = mctx_cur->get_v(ctx0, il);
|
||||
|
||||
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, sinks, kq_scale);
|
||||
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale);
|
||||
cb(cur, "kqv_out", il);
|
||||
|
||||
if (wo) {
|
||||
|
|
@ -1600,6 +1577,7 @@ ggml_tensor * llm_graph_context::build_attn(
|
|||
ggml_tensor * k_cur,
|
||||
ggml_tensor * v_cur,
|
||||
ggml_tensor * kq_b,
|
||||
ggml_tensor * sinks,
|
||||
ggml_tensor * v_mla,
|
||||
float kq_scale,
|
||||
int il) const {
|
||||
|
|
@ -1615,7 +1593,7 @@ ggml_tensor * llm_graph_context::build_attn(
|
|||
ggml_tensor * k = k_cur;
|
||||
ggml_tensor * v = v_cur;
|
||||
|
||||
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, nullptr, kq_scale);
|
||||
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale);
|
||||
cb(cur, "kqv_out", il);
|
||||
|
||||
if (wo) {
|
||||
|
|
@ -1636,10 +1614,10 @@ ggml_tensor * llm_graph_context::build_attn(
|
|||
// TODO: maybe separate the inner implementation into a separate function
|
||||
// like with the non-sliding window equivalent
|
||||
// once sliding-window hybrid caches are a thing.
|
||||
llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
|
||||
const auto * mctx_cur = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);
|
||||
llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const {
|
||||
const auto * mctx_cur = static_cast<const llama_kv_cache_iswa_context *>(mctx);
|
||||
|
||||
auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);
|
||||
auto inp = std::make_unique<llm_graph_input_attn_kv_iswa>(hparams, cparams, mctx_cur);
|
||||
|
||||
const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
|
||||
|
||||
|
|
@ -1656,7 +1634,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
|
|||
}
|
||||
|
||||
{
|
||||
GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
|
||||
GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache for non-SWA");
|
||||
|
||||
const auto n_kv = mctx_cur->get_swa()->get_n_kv();
|
||||
|
||||
|
|
@ -1669,7 +1647,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
|
|||
inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
|
||||
}
|
||||
|
||||
return (llm_graph_input_attn_kv_unified_iswa *) res->add_input(std::move(inp));
|
||||
return (llm_graph_input_attn_kv_iswa *) res->add_input(std::move(inp));
|
||||
}
|
||||
|
||||
ggml_tensor * llm_graph_context::build_rs(
|
||||
|
|
@ -1792,7 +1770,7 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
|
|||
const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
|
||||
|
||||
auto inp_rs = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr());
|
||||
auto inp_attn = build_attn_inp_kv_unified_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
|
||||
auto inp_attn = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
|
||||
|
||||
auto inp = std::make_unique<llm_graph_input_mem_hybrid>(std::move(inp_attn), std::move(inp_rs), mctx_cur);
|
||||
|
||||
|
|
|
|||
|
|
@ -19,8 +19,8 @@ struct llama_cparams;
|
|||
|
||||
struct llama_memory_context_i;
|
||||
|
||||
class llama_kv_cache_unified_context;
|
||||
class llama_kv_cache_unified_iswa_context;
|
||||
class llama_kv_cache_context;
|
||||
class llama_kv_cache_iswa_context;
|
||||
class llama_memory_recurrent_context;
|
||||
class llama_memory_hybrid_context;
|
||||
|
||||
|
|
@ -152,7 +152,7 @@ class llm_graph_input_pos_bucket_kv : public llm_graph_input_i {
|
|||
public:
|
||||
llm_graph_input_pos_bucket_kv(
|
||||
const llama_hparams & hparams,
|
||||
const llama_kv_cache_unified_context * mctx) : hparams(hparams), mctx(mctx) {}
|
||||
const llama_kv_cache_context * mctx) : hparams(hparams), mctx(mctx) {}
|
||||
virtual ~llm_graph_input_pos_bucket_kv() = default;
|
||||
|
||||
void set_input(const llama_ubatch * ubatch) override;
|
||||
|
|
@ -161,7 +161,7 @@ public:
|
|||
|
||||
const llama_hparams hparams;
|
||||
|
||||
const llama_kv_cache_unified_context * mctx;
|
||||
const llama_kv_cache_context * mctx;
|
||||
};
|
||||
|
||||
class llm_graph_input_out_ids : public llm_graph_input_i {
|
||||
|
|
@ -257,17 +257,17 @@ public:
|
|||
const llama_cparams cparams;
|
||||
};
|
||||
|
||||
class llm_graph_input_attn_kv_unified : public llm_graph_input_i {
|
||||
class llm_graph_input_attn_kv : public llm_graph_input_i {
|
||||
public:
|
||||
llm_graph_input_attn_kv_unified(
|
||||
llm_graph_input_attn_kv(
|
||||
const llama_hparams & hparams,
|
||||
const llama_cparams & cparams,
|
||||
const llama_kv_cache_unified_context * mctx) :
|
||||
const llama_kv_cache_context * mctx) :
|
||||
hparams(hparams),
|
||||
cparams(cparams),
|
||||
mctx(mctx) {
|
||||
}
|
||||
~llm_graph_input_attn_kv_unified() = default;
|
||||
~llm_graph_input_attn_kv() = default;
|
||||
|
||||
void set_input(const llama_ubatch * ubatch) override;
|
||||
|
||||
|
|
@ -290,20 +290,20 @@ public:
|
|||
const llama_hparams hparams;
|
||||
const llama_cparams cparams;
|
||||
|
||||
const llama_kv_cache_unified_context * mctx;
|
||||
const llama_kv_cache_context * mctx;
|
||||
};
|
||||
|
||||
class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
|
||||
class llm_graph_input_attn_kv_iswa : public llm_graph_input_i {
|
||||
public:
|
||||
llm_graph_input_attn_kv_unified_iswa(
|
||||
llm_graph_input_attn_kv_iswa(
|
||||
const llama_hparams & hparams,
|
||||
const llama_cparams & cparams,
|
||||
const llama_kv_cache_unified_iswa_context * mctx) :
|
||||
const llama_kv_cache_iswa_context * mctx) :
|
||||
hparams(hparams),
|
||||
cparams(cparams),
|
||||
mctx(mctx) {
|
||||
}
|
||||
~llm_graph_input_attn_kv_unified_iswa() = default;
|
||||
~llm_graph_input_attn_kv_iswa() = default;
|
||||
|
||||
void set_input(const llama_ubatch * ubatch) override;
|
||||
|
||||
|
|
@ -330,7 +330,7 @@ public:
|
|||
const llama_hparams hparams;
|
||||
const llama_cparams cparams;
|
||||
|
||||
const llama_kv_cache_unified_iswa_context * mctx;
|
||||
const llama_kv_cache_iswa_context * mctx;
|
||||
};
|
||||
|
||||
class llm_graph_input_attn_cross : public llm_graph_input_i {
|
||||
|
|
@ -351,7 +351,7 @@ public:
|
|||
class llm_graph_input_mem_hybrid : public llm_graph_input_i {
|
||||
public:
|
||||
llm_graph_input_mem_hybrid(
|
||||
std::unique_ptr<llm_graph_input_attn_kv_unified> inp_attn,
|
||||
std::unique_ptr<llm_graph_input_attn_kv> inp_attn,
|
||||
std::unique_ptr<llm_graph_input_rs> inp_rs,
|
||||
const llama_memory_hybrid_context * mctx) :
|
||||
inp_attn(std::move(inp_attn)),
|
||||
|
|
@ -361,11 +361,11 @@ public:
|
|||
|
||||
void set_input(const llama_ubatch * ubatch) override;
|
||||
|
||||
std::unique_ptr<llm_graph_input_attn_kv_unified> inp_attn;
|
||||
std::unique_ptr<llm_graph_input_rs> inp_rs;
|
||||
std::unique_ptr<llm_graph_input_attn_kv> inp_attn;
|
||||
std::unique_ptr<llm_graph_input_rs> inp_rs;
|
||||
|
||||
llm_graph_input_attn_kv_unified * get_attn() const { return inp_attn.get(); }
|
||||
llm_graph_input_rs * get_recr() const { return inp_rs.get(); }
|
||||
llm_graph_input_attn_kv * get_attn() const { return inp_attn.get(); }
|
||||
llm_graph_input_rs * get_recr() const { return inp_rs.get(); }
|
||||
|
||||
const llama_memory_hybrid_context * mctx;
|
||||
};
|
||||
|
|
@ -680,14 +680,14 @@ struct llm_graph_context {
|
|||
//
|
||||
|
||||
ggml_tensor * build_attn_mha(
|
||||
ggml_tensor * q, // [n_embd_head_q, n_head_q, n_tokens]
|
||||
ggml_tensor * k, // [n_embd_head_k, n_head_k, n_tokens]
|
||||
ggml_tensor * v, // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false)
|
||||
ggml_tensor * kq_b,
|
||||
ggml_tensor * kq_mask,
|
||||
ggml_tensor * sinks,
|
||||
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
|
||||
float kq_scale) const;
|
||||
ggml_tensor * q, // [n_embd_head_q, n_head_q, n_tokens]
|
||||
ggml_tensor * k, // [n_embd_head_k, n_head_k, n_tokens]
|
||||
ggml_tensor * v, // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false)
|
||||
ggml_tensor * kq_b,
|
||||
ggml_tensor * kq_mask,
|
||||
ggml_tensor * sinks, // [n_head_q]
|
||||
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
|
||||
float kq_scale) const;
|
||||
|
||||
llm_graph_input_attn_no_cache * build_attn_inp_no_cache() const;
|
||||
|
||||
|
|
@ -699,50 +699,39 @@ struct llm_graph_context {
|
|||
ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
|
||||
ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
|
||||
ggml_tensor * kq_b,
|
||||
ggml_tensor * sinks, // [n_head_q]
|
||||
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
|
||||
float kq_scale,
|
||||
int il) const;
|
||||
|
||||
llm_graph_input_attn_kv_unified * build_attn_inp_kv_unified() const;
|
||||
llm_graph_input_attn_kv * build_attn_inp_kv() const;
|
||||
|
||||
ggml_tensor * build_attn(
|
||||
llm_graph_input_attn_kv_unified * inp,
|
||||
llm_graph_input_attn_kv * inp,
|
||||
ggml_tensor * wo,
|
||||
ggml_tensor * wo_b,
|
||||
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
|
||||
ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
|
||||
ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
|
||||
ggml_tensor * kq_b,
|
||||
ggml_tensor * sinks, // [n_head_q]
|
||||
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
|
||||
float kq_scale,
|
||||
int il) const;
|
||||
|
||||
llm_graph_input_attn_kv_unified_iswa * build_attn_inp_kv_unified_iswa() const;
|
||||
llm_graph_input_attn_kv_iswa * build_attn_inp_kv_iswa() const;
|
||||
|
||||
// note: if k_cur or v_cur are not provided, they will not be stored in the memory
|
||||
ggml_tensor * build_attn(
|
||||
llm_graph_input_attn_kv_unified_iswa * inp,
|
||||
llm_graph_input_attn_kv_iswa * inp,
|
||||
ggml_tensor * wo,
|
||||
ggml_tensor * wo_b,
|
||||
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
|
||||
ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] optional
|
||||
ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] optional
|
||||
ggml_tensor * kq_b,
|
||||
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
|
||||
float kq_scale,
|
||||
int il) const;
|
||||
|
||||
// TODO: temporary to keep the diff small. after the code is public will refactor to simplify this
|
||||
ggml_tensor * build_attn_with_sinks(
|
||||
llm_graph_input_attn_kv_unified_iswa * inp,
|
||||
ggml_tensor * wo,
|
||||
ggml_tensor * wo_b,
|
||||
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
|
||||
ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] optional
|
||||
ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] optional
|
||||
ggml_tensor * kq_b,
|
||||
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
|
||||
ggml_tensor * sinks, // [n_head_q]
|
||||
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
|
||||
float kq_scale,
|
||||
int il) const;
|
||||
|
||||
|
|
@ -756,6 +745,7 @@ struct llm_graph_context {
|
|||
ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
|
||||
ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
|
||||
ggml_tensor * kq_b,
|
||||
ggml_tensor * sinks, // [n_head_q]
|
||||
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
|
||||
float kq_scale,
|
||||
int il) const;
|
||||
|
|
@ -765,7 +755,7 @@ struct llm_graph_context {
|
|||
//
|
||||
|
||||
// TODO: move this implementation to llama_memory_recurrent.
|
||||
// this is analogous to llama_kv_cache_unified::cpy_k / cpy_v
|
||||
// this is analogous to llama_kv_cache::cpy_k / cpy_v
|
||||
// when moving, avoid passing `ggml_cgraph` - only pass `ggml_context`. would likely need to split the
|
||||
// implementation in 2 separate methods. the goal is to avoid calling `ggml_build_forward_expand` in
|
||||
// `llama_memory_recurrent`
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
#include "llama-kv-cache-unified-iswa.h"
|
||||
#include "llama-kv-cache-iswa.h"
|
||||
|
||||
#include "llama-impl.h"
|
||||
#include "llama-batch.h"
|
||||
|
|
@ -8,10 +8,10 @@
|
|||
#include <cassert>
|
||||
|
||||
//
|
||||
// llama_kv_cache_unified_iswa
|
||||
// llama_kv_cache_iswa
|
||||
//
|
||||
|
||||
llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
|
||||
llama_kv_cache_iswa::llama_kv_cache_iswa(
|
||||
const llama_model & model,
|
||||
ggml_type type_k,
|
||||
ggml_type type_v,
|
||||
|
|
@ -23,8 +23,8 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
|
|||
uint32_t n_seq_max,
|
||||
uint32_t n_ubatch,
|
||||
uint32_t n_pad) : hparams(model.hparams), unified(unified) {
|
||||
llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
|
||||
llama_kv_cache_unified::layer_filter_cb filter_swa = [&](int32_t il) { return model.hparams.is_swa(il); };
|
||||
llama_kv_cache::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
|
||||
llama_kv_cache::layer_filter_cb filter_swa = [&](int32_t il) { return model.hparams.is_swa(il); };
|
||||
|
||||
const uint32_t size_base = kv_size;
|
||||
|
||||
|
|
@ -44,25 +44,25 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
|
|||
|
||||
LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base);
|
||||
|
||||
kv_base = std::make_unique<llama_kv_cache_unified>(
|
||||
kv_base = std::make_unique<llama_kv_cache>(
|
||||
model, std::move(filter_base), type_k, type_v,
|
||||
v_trans, offload, unified, size_base, n_seq_max, n_pad,
|
||||
0, LLAMA_SWA_TYPE_NONE);
|
||||
|
||||
LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);
|
||||
|
||||
kv_swa = std::make_unique<llama_kv_cache_unified>(
|
||||
kv_swa = std::make_unique<llama_kv_cache>(
|
||||
model, std::move(filter_swa), type_k, type_v,
|
||||
v_trans, offload, unified, size_swa, n_seq_max, n_pad,
|
||||
hparams.n_swa, hparams.swa_type);
|
||||
}
|
||||
|
||||
void llama_kv_cache_unified_iswa::clear(bool data) {
|
||||
void llama_kv_cache_iswa::clear(bool data) {
|
||||
kv_base->clear(data);
|
||||
kv_swa ->clear(data);
|
||||
}
|
||||
|
||||
bool llama_kv_cache_unified_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
|
||||
bool llama_kv_cache_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
|
||||
bool res = true;
|
||||
|
||||
res = res & kv_base->seq_rm(seq_id, p0, p1);
|
||||
|
|
@ -71,36 +71,36 @@ bool llama_kv_cache_unified_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llam
|
|||
return res;
|
||||
}
|
||||
|
||||
void llama_kv_cache_unified_iswa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
|
||||
void llama_kv_cache_iswa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
|
||||
kv_base->seq_cp(seq_id_src, seq_id_dst, p0, p1);
|
||||
kv_swa ->seq_cp(seq_id_src, seq_id_dst, p0, p1);
|
||||
}
|
||||
|
||||
void llama_kv_cache_unified_iswa::seq_keep(llama_seq_id seq_id) {
|
||||
void llama_kv_cache_iswa::seq_keep(llama_seq_id seq_id) {
|
||||
kv_base->seq_keep(seq_id);
|
||||
kv_swa ->seq_keep(seq_id);
|
||||
}
|
||||
|
||||
void llama_kv_cache_unified_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
|
||||
void llama_kv_cache_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
|
||||
kv_base->seq_add(seq_id, p0, p1, shift);
|
||||
kv_swa ->seq_add(seq_id, p0, p1, shift);
|
||||
}
|
||||
|
||||
void llama_kv_cache_unified_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
|
||||
void llama_kv_cache_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
|
||||
kv_base->seq_div(seq_id, p0, p1, d);
|
||||
kv_swa ->seq_div(seq_id, p0, p1, d);
|
||||
}
|
||||
|
||||
llama_pos llama_kv_cache_unified_iswa::seq_pos_min(llama_seq_id seq_id) const {
|
||||
llama_pos llama_kv_cache_iswa::seq_pos_min(llama_seq_id seq_id) const {
|
||||
// the base cache is a superset of the SWA cache, so we can just check the SWA cache
|
||||
return kv_swa->seq_pos_min(seq_id);
|
||||
}
|
||||
|
||||
llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
|
||||
llama_pos llama_kv_cache_iswa::seq_pos_max(llama_seq_id seq_id) const {
|
||||
return kv_swa->seq_pos_max(seq_id);
|
||||
}
|
||||
|
||||
llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
|
||||
llama_memory_context_ptr llama_kv_cache_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
|
||||
GGML_UNUSED(embd_all);
|
||||
|
||||
// first try simple split
|
||||
|
|
@ -140,7 +140,7 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
|
|||
|
||||
assert(sinfos_base.size() == sinfos_swa.size());
|
||||
|
||||
return std::make_unique<llama_kv_cache_unified_iswa_context>(
|
||||
return std::make_unique<llama_kv_cache_iswa_context>(
|
||||
this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
|
||||
} while (false);
|
||||
|
||||
|
|
@ -176,29 +176,29 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
|
|||
|
||||
assert(sinfos_base.size() == sinfos_swa.size());
|
||||
|
||||
return std::make_unique<llama_kv_cache_unified_iswa_context>(
|
||||
return std::make_unique<llama_kv_cache_iswa_context>(
|
||||
this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
|
||||
} while (false);
|
||||
|
||||
// TODO: if we fail again, we should attempt different splitting strategies
|
||||
// but to do that properly, we first have to refactor the batches to be more flexible
|
||||
|
||||
return std::make_unique<llama_kv_cache_unified_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
|
||||
return std::make_unique<llama_kv_cache_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
|
||||
}
|
||||
|
||||
llama_memory_context_ptr llama_kv_cache_unified_iswa::init_full() {
|
||||
return std::make_unique<llama_kv_cache_unified_iswa_context>(this);
|
||||
llama_memory_context_ptr llama_kv_cache_iswa::init_full() {
|
||||
return std::make_unique<llama_kv_cache_iswa_context>(this);
|
||||
}
|
||||
|
||||
llama_memory_context_ptr llama_kv_cache_unified_iswa::init_update(llama_context * lctx, bool optimize) {
|
||||
return std::make_unique<llama_kv_cache_unified_iswa_context>(this, lctx, optimize);
|
||||
llama_memory_context_ptr llama_kv_cache_iswa::init_update(llama_context * lctx, bool optimize) {
|
||||
return std::make_unique<llama_kv_cache_iswa_context>(this, lctx, optimize);
|
||||
}
|
||||
|
||||
bool llama_kv_cache_unified_iswa::get_can_shift() const {
|
||||
bool llama_kv_cache_iswa::get_can_shift() const {
|
||||
return kv_base->get_size() == kv_swa->get_size();
|
||||
}
|
||||
|
||||
void llama_kv_cache_unified_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
|
||||
void llama_kv_cache_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
|
||||
if ((flags & LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) == 0) {
|
||||
kv_base->state_write(io, seq_id, flags);
|
||||
}
|
||||
|
|
@ -206,7 +206,7 @@ void llama_kv_cache_unified_iswa::state_write(llama_io_write_i & io, llama_seq_i
|
|||
kv_swa->state_write(io, seq_id, flags);
|
||||
}
|
||||
|
||||
void llama_kv_cache_unified_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
|
||||
void llama_kv_cache_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
|
||||
if ((flags & LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) == 0) {
|
||||
kv_base->state_read(io, seq_id, flags);
|
||||
}
|
||||
|
|
@ -214,29 +214,29 @@ void llama_kv_cache_unified_iswa::state_read(llama_io_read_i & io, llama_seq_id
|
|||
kv_swa->state_read(io, seq_id, flags);
|
||||
}
|
||||
|
||||
llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_base() const {
|
||||
llama_kv_cache * llama_kv_cache_iswa::get_base() const {
|
||||
return kv_base.get();
|
||||
}
|
||||
|
||||
llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const {
|
||||
llama_kv_cache * llama_kv_cache_iswa::get_swa() const {
|
||||
return kv_swa.get();
|
||||
}
|
||||
|
||||
//
|
||||
// llama_kv_cache_unified_iswa_context
|
||||
// llama_kv_cache_iswa_context
|
||||
//
|
||||
|
||||
llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(llama_memory_status status) : status(status) {}
|
||||
llama_kv_cache_iswa_context::llama_kv_cache_iswa_context(llama_memory_status status) : status(status) {}
|
||||
|
||||
llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
|
||||
llama_kv_cache_unified_iswa * kv) :
|
||||
llama_kv_cache_iswa_context::llama_kv_cache_iswa_context(
|
||||
llama_kv_cache_iswa * kv) :
|
||||
ctx_base(kv->get_base()->init_full()),
|
||||
ctx_swa (kv->get_swa ()->init_full()),
|
||||
status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
|
||||
}
|
||||
|
||||
llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
|
||||
llama_kv_cache_unified_iswa * kv,
|
||||
llama_kv_cache_iswa_context::llama_kv_cache_iswa_context(
|
||||
llama_kv_cache_iswa * kv,
|
||||
llama_context * lctx,
|
||||
bool optimize) :
|
||||
ctx_base(kv->get_base()->init_update(lctx, optimize)),
|
||||
|
|
@ -244,21 +244,21 @@ llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
|
|||
status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
|
||||
}
|
||||
|
||||
llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
|
||||
llama_kv_cache_unified_iswa * kv,
|
||||
llama_kv_cache_iswa_context::llama_kv_cache_iswa_context(
|
||||
llama_kv_cache_iswa * kv,
|
||||
slot_info_vec_t sinfos_base,
|
||||
slot_info_vec_t sinfos_swa,
|
||||
std::vector<llama_ubatch> ubatches) :
|
||||
ubatches(std::move(ubatches)),
|
||||
// note: here we copy the ubatches. not sure if this is ideal
|
||||
ctx_base(new llama_kv_cache_unified_context(kv->get_base(), std::move(sinfos_base), this->ubatches)),
|
||||
ctx_swa (new llama_kv_cache_unified_context(kv->get_swa (), std::move(sinfos_swa), this->ubatches)),
|
||||
ctx_base(new llama_kv_cache_context(kv->get_base(), std::move(sinfos_base), this->ubatches)),
|
||||
ctx_swa (new llama_kv_cache_context(kv->get_swa (), std::move(sinfos_swa), this->ubatches)),
|
||||
status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
|
||||
}
|
||||
|
||||
llama_kv_cache_unified_iswa_context:: ~llama_kv_cache_unified_iswa_context() = default;
|
||||
llama_kv_cache_iswa_context:: ~llama_kv_cache_iswa_context() = default;
|
||||
|
||||
bool llama_kv_cache_unified_iswa_context::next() {
|
||||
bool llama_kv_cache_iswa_context::next() {
|
||||
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
||||
|
||||
ctx_base->next();
|
||||
|
|
@ -271,7 +271,7 @@ bool llama_kv_cache_unified_iswa_context::next() {
|
|||
return true;
|
||||
}
|
||||
|
||||
bool llama_kv_cache_unified_iswa_context::apply() {
|
||||
bool llama_kv_cache_iswa_context::apply() {
|
||||
assert(!llama_memory_status_is_fail(status));
|
||||
|
||||
bool res = true;
|
||||
|
|
@ -282,24 +282,24 @@ bool llama_kv_cache_unified_iswa_context::apply() {
|
|||
return res;
|
||||
}
|
||||
|
||||
llama_memory_status llama_kv_cache_unified_iswa_context::get_status() const {
|
||||
llama_memory_status llama_kv_cache_iswa_context::get_status() const {
|
||||
return status;
|
||||
}
|
||||
|
||||
const llama_ubatch & llama_kv_cache_unified_iswa_context::get_ubatch() const {
|
||||
const llama_ubatch & llama_kv_cache_iswa_context::get_ubatch() const {
|
||||
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
||||
|
||||
return ubatches[i_next];
|
||||
}
|
||||
|
||||
const llama_kv_cache_unified_context * llama_kv_cache_unified_iswa_context::get_base() const {
|
||||
const llama_kv_cache_context * llama_kv_cache_iswa_context::get_base() const {
|
||||
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
||||
|
||||
return static_cast<const llama_kv_cache_unified_context *>(ctx_base.get());
|
||||
return static_cast<const llama_kv_cache_context *>(ctx_base.get());
|
||||
}
|
||||
|
||||
const llama_kv_cache_unified_context * llama_kv_cache_unified_iswa_context::get_swa() const {
|
||||
const llama_kv_cache_context * llama_kv_cache_iswa_context::get_swa() const {
|
||||
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
||||
|
||||
return static_cast<const llama_kv_cache_unified_context *>(ctx_swa.get());
|
||||
return static_cast<const llama_kv_cache_context *>(ctx_swa.get());
|
||||
}
|
||||
|
|
@ -1,32 +1,32 @@
|
|||
#pragma once
|
||||
|
||||
#include "llama-kv-cache-unified.h"
|
||||
#include "llama-kv-cache.h"
|
||||
|
||||
#include <vector>
|
||||
|
||||
//
|
||||
// llama_kv_cache_unified_iswa
|
||||
// llama_kv_cache_iswa
|
||||
//
|
||||
|
||||
// utilizes two instances of llama_kv_cache_unified
|
||||
// utilizes two instances of llama_kv_cache
|
||||
// the first instance is for the non-SWA layers of the model and the second instance is for the SWA layers
|
||||
|
||||
class llama_kv_cache_unified_iswa : public llama_memory_i {
|
||||
class llama_kv_cache_iswa : public llama_memory_i {
|
||||
public:
|
||||
llama_kv_cache_unified_iswa(
|
||||
llama_kv_cache_iswa(
|
||||
const llama_model & model,
|
||||
ggml_type type_k,
|
||||
ggml_type type_v,
|
||||
bool v_trans,
|
||||
bool offload,
|
||||
bool swa_full,
|
||||
bool unified,
|
||||
bool ,
|
||||
uint32_t kv_size,
|
||||
uint32_t n_seq_max,
|
||||
uint32_t n_ubatch,
|
||||
uint32_t n_pad);
|
||||
|
||||
~llama_kv_cache_unified_iswa() = default;
|
||||
~llama_kv_cache_iswa() = default;
|
||||
|
||||
//
|
||||
// llama_memory_i
|
||||
|
|
@ -60,46 +60,46 @@ public:
|
|||
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;
|
||||
|
||||
//
|
||||
// llama_kv_cache_unified_iswa specific API
|
||||
// llama_kv_cache_iswa specific API
|
||||
//
|
||||
|
||||
llama_kv_cache_unified * get_base() const;
|
||||
llama_kv_cache_unified * get_swa () const;
|
||||
llama_kv_cache * get_base() const;
|
||||
llama_kv_cache * get_swa () const;
|
||||
|
||||
private:
|
||||
const llama_hparams & hparams;
|
||||
|
||||
const bool unified;
|
||||
|
||||
std::unique_ptr<llama_kv_cache_unified> kv_base;
|
||||
std::unique_ptr<llama_kv_cache_unified> kv_swa;
|
||||
std::unique_ptr<llama_kv_cache> kv_base;
|
||||
std::unique_ptr<llama_kv_cache> kv_swa;
|
||||
};
|
||||
|
||||
class llama_kv_cache_unified_iswa_context : public llama_memory_context_i {
|
||||
class llama_kv_cache_iswa_context : public llama_memory_context_i {
|
||||
public:
|
||||
using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
|
||||
using slot_info_vec_t = llama_kv_cache::slot_info_vec_t;
|
||||
|
||||
// used for errors
|
||||
llama_kv_cache_unified_iswa_context(llama_memory_status status);
|
||||
llama_kv_cache_iswa_context(llama_memory_status status);
|
||||
|
||||
// used to create a full-cache context
|
||||
llama_kv_cache_unified_iswa_context(
|
||||
llama_kv_cache_unified_iswa * kv);
|
||||
llama_kv_cache_iswa_context(
|
||||
llama_kv_cache_iswa * kv);
|
||||
|
||||
// used to create an update context
|
||||
llama_kv_cache_unified_iswa_context(
|
||||
llama_kv_cache_unified_iswa * kv,
|
||||
llama_kv_cache_iswa_context(
|
||||
llama_kv_cache_iswa * kv,
|
||||
llama_context * lctx,
|
||||
bool optimize);
|
||||
|
||||
// used to create a batch processing context from a batch
|
||||
llama_kv_cache_unified_iswa_context(
|
||||
llama_kv_cache_unified_iswa * kv,
|
||||
llama_kv_cache_iswa_context(
|
||||
llama_kv_cache_iswa * kv,
|
||||
slot_info_vec_t sinfos_base,
|
||||
slot_info_vec_t sinfos_swa,
|
||||
std::vector<llama_ubatch> ubatches);
|
||||
|
||||
virtual ~llama_kv_cache_unified_iswa_context();
|
||||
virtual ~llama_kv_cache_iswa_context();
|
||||
|
||||
//
|
||||
// llama_memory_context_i
|
||||
|
|
@ -112,14 +112,14 @@ public:
|
|||
const llama_ubatch & get_ubatch() const override;
|
||||
|
||||
//
|
||||
// llama_kv_cache_unified_iswa_context specific API
|
||||
// llama_kv_cache_iswa_context specific API
|
||||
//
|
||||
|
||||
const llama_kv_cache_unified_context * get_base() const;
|
||||
const llama_kv_cache_unified_context * get_swa() const;
|
||||
const llama_kv_cache_context * get_base() const;
|
||||
const llama_kv_cache_context * get_swa() const;
|
||||
|
||||
private:
|
||||
//llama_kv_cache_unified_iswa * kv;
|
||||
//llama_kv_cache_iswa * kv;
|
||||
|
||||
// the index of the next ubatch to process
|
||||
size_t i_next = 0;
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
#include "llama-kv-cache-unified.h"
|
||||
#include "llama-kv-cache.h"
|
||||
|
||||
#include "llama-impl.h"
|
||||
#include "llama-io.h"
|
||||
|
|
@ -13,10 +13,10 @@
|
|||
#include <stdexcept>
|
||||
|
||||
//
|
||||
// llama_kv_cache_unified
|
||||
// llama_kv_cache
|
||||
//
|
||||
|
||||
llama_kv_cache_unified::llama_kv_cache_unified(
|
||||
llama_kv_cache::llama_kv_cache(
|
||||
const llama_model & model,
|
||||
layer_filter_cb && filter,
|
||||
ggml_type type_k,
|
||||
|
|
@ -209,7 +209,7 @@ llama_kv_cache_unified::llama_kv_cache_unified(
|
|||
}
|
||||
}
|
||||
|
||||
void llama_kv_cache_unified::clear(bool data) {
|
||||
void llama_kv_cache::clear(bool data) {
|
||||
for (uint32_t s = 0; s < n_stream; ++s) {
|
||||
v_cells[s].reset();
|
||||
v_heads[s] = 0;
|
||||
|
|
@ -222,7 +222,7 @@ void llama_kv_cache_unified::clear(bool data) {
|
|||
}
|
||||
}
|
||||
|
||||
bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
|
||||
bool llama_kv_cache::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
|
||||
GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()));
|
||||
|
||||
if (p0 < 0) {
|
||||
|
|
@ -285,7 +285,7 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
|
|||
return true;
|
||||
}
|
||||
|
||||
void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
|
||||
void llama_kv_cache::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
|
||||
GGML_ASSERT(seq_id_src >= 0 && (size_t) seq_id_src < seq_to_stream.size());
|
||||
GGML_ASSERT(seq_id_dst >= 0 && (size_t) seq_id_dst < seq_to_stream.size());
|
||||
|
||||
|
|
@ -368,7 +368,7 @@ void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id
|
|||
//}
|
||||
}
|
||||
|
||||
void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) {
|
||||
void llama_kv_cache::seq_keep(llama_seq_id seq_id) {
|
||||
GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
|
||||
|
||||
auto & cells = v_cells[seq_to_stream[seq_id]];
|
||||
|
|
@ -390,7 +390,7 @@ void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) {
|
|||
}
|
||||
}
|
||||
|
||||
void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
|
||||
void llama_kv_cache::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
|
||||
GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
|
||||
|
||||
auto & cells = v_cells[seq_to_stream[seq_id]];
|
||||
|
|
@ -434,7 +434,7 @@ void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_po
|
|||
head = new_head != cells.size() ? new_head : 0;
|
||||
}
|
||||
|
||||
void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
|
||||
void llama_kv_cache::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
|
||||
GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
|
||||
|
||||
auto & cells = v_cells[seq_to_stream[seq_id]];
|
||||
|
|
@ -467,7 +467,7 @@ void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_po
|
|||
}
|
||||
}
|
||||
|
||||
llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const {
|
||||
llama_pos llama_kv_cache::seq_pos_min(llama_seq_id seq_id) const {
|
||||
GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
|
||||
|
||||
const auto & cells = v_cells[seq_to_stream[seq_id]];
|
||||
|
|
@ -475,7 +475,7 @@ llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const {
|
|||
return cells.seq_pos_min(seq_id);
|
||||
}
|
||||
|
||||
llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
|
||||
llama_pos llama_kv_cache::seq_pos_max(llama_seq_id seq_id) const {
|
||||
GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
|
||||
|
||||
const auto & cells = v_cells[seq_to_stream[seq_id]];
|
||||
|
|
@ -483,7 +483,7 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
|
|||
return cells.seq_pos_max(seq_id);
|
||||
}
|
||||
|
||||
llama_memory_context_ptr llama_kv_cache_unified::init_batch(
|
||||
llama_memory_context_ptr llama_kv_cache::init_batch(
|
||||
llama_batch_allocr & balloc,
|
||||
uint32_t n_ubatch,
|
||||
bool embd_all) {
|
||||
|
|
@ -513,62 +513,34 @@ llama_memory_context_ptr llama_kv_cache_unified::init_batch(
|
|||
break;
|
||||
}
|
||||
|
||||
return std::make_unique<llama_kv_cache_unified_context>(
|
||||
return std::make_unique<llama_kv_cache_context>(
|
||||
this, std::move(sinfos), std::move(ubatches));
|
||||
} while (false);
|
||||
|
||||
return std::make_unique<llama_kv_cache_unified_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
|
||||
return std::make_unique<llama_kv_cache_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
|
||||
}
|
||||
|
||||
llama_memory_context_ptr llama_kv_cache_unified::init_full() {
|
||||
return std::make_unique<llama_kv_cache_unified_context>(this);
|
||||
llama_memory_context_ptr llama_kv_cache::init_full() {
|
||||
return std::make_unique<llama_kv_cache_context>(this);
|
||||
}
|
||||
|
||||
llama_memory_context_ptr llama_kv_cache_unified::init_update(llama_context * lctx, bool optimize) {
|
||||
llama_memory_context_ptr llama_kv_cache::init_update(llama_context * lctx, bool optimize) {
|
||||
GGML_UNUSED(optimize);
|
||||
|
||||
bool do_shift = get_has_shift();
|
||||
|
||||
defrag_info dinfo;
|
||||
|
||||
// see if we need to defrag
|
||||
if (n_stream == 1) {
|
||||
// note : for now do not consider defrag for n_stream > 1
|
||||
const auto & cells = v_cells[seq_to_stream[0]];
|
||||
|
||||
bool do_defrag = optimize;
|
||||
|
||||
const auto thold = lctx->get_cparams().defrag_thold;
|
||||
|
||||
if (!do_defrag && thold > 0.0f) {
|
||||
const auto n_kv = cells.used_max_p1();
|
||||
|
||||
// - do not defrag small contexts (i.e. < 2048 tokens)
|
||||
// - count the padding towards the number of used tokens
|
||||
const float fragmentation = n_kv >= 2048 ? std::max(0.0f, 1.0f - (float(cells.get_used() + n_pad)/n_kv)) : 0.0f;
|
||||
|
||||
if (fragmentation > thold) {
|
||||
LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);
|
||||
|
||||
do_defrag = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (do_defrag) {
|
||||
dinfo = defrag_prepare(lctx->graph_max_nodes());
|
||||
}
|
||||
}
|
||||
|
||||
return std::make_unique<llama_kv_cache_unified_context>(this, lctx, do_shift, std::move(dinfo), std::move(sc_info));
|
||||
return std::make_unique<llama_kv_cache_context>(this, lctx, do_shift, std::move(sc_info));
|
||||
}
|
||||
|
||||
llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
|
||||
llama_kv_cache_unified::slot_info_vec_t res;
|
||||
llama_kv_cache::slot_info_vec_t llama_kv_cache::prepare(const std::vector<llama_ubatch> & ubatches) {
|
||||
llama_kv_cache::slot_info_vec_t res;
|
||||
|
||||
struct state_t {
|
||||
slot_info sinfo; // slot info for the ubatch
|
||||
|
||||
std::vector<uint32_t> v_heads_old; // old positions of the heads, before placing the ubatch
|
||||
|
||||
std::vector<llama_kv_cells_unified> v_cells; // copy of the old cells, before placing the ubatch
|
||||
std::vector<llama_kv_cells> v_cells; // copy of the old cells, before placing the ubatch
|
||||
};
|
||||
|
||||
// remember the old state of the cells so we can restore it in the end
|
||||
|
|
@ -629,7 +601,7 @@ llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const st
|
|||
return res;
|
||||
}
|
||||
|
||||
bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const defrag_info & dinfo, const stream_copy_info & sc_info) {
|
||||
bool llama_kv_cache::update(llama_context * lctx, bool do_shift, const stream_copy_info & sc_info) {
|
||||
bool updated = false;
|
||||
|
||||
auto * sched = lctx->get_sched();
|
||||
|
|
@ -699,57 +671,10 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d
|
|||
}
|
||||
}}
|
||||
|
||||
if (!dinfo.empty()) {
|
||||
LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);
|
||||
|
||||
// note: for now do not consider defrag for n_stream > 1
|
||||
auto & cells = v_cells[seq_to_stream[0]];
|
||||
auto & head = v_heads[seq_to_stream[0]];
|
||||
|
||||
// apply moves:
|
||||
{
|
||||
const auto n_kv = dinfo.ids.size();
|
||||
|
||||
for (uint32_t i = 0; i < n_kv; ++i) {
|
||||
assert(dinfo.ids[i] <= n_kv);
|
||||
|
||||
if (dinfo.ids[i] == n_kv || dinfo.ids[i] == i) {
|
||||
continue;
|
||||
}
|
||||
|
||||
cells.mv(i, dinfo.ids[i]);
|
||||
}
|
||||
|
||||
// reset the head so we can find the first free slot during the next ubatch
|
||||
head = 0;
|
||||
}
|
||||
|
||||
ggml_backend_sched_reset(sched);
|
||||
|
||||
auto * res = lctx->get_gf_res_reserve();
|
||||
|
||||
res->reset();
|
||||
|
||||
auto * gf = build_graph_defrag(res, lctx, dinfo);
|
||||
if (!ggml_backend_sched_alloc_graph(sched, gf)) {
|
||||
LLAMA_LOG_ERROR("%s: failed to allocate compute graph for defrag\n", __func__);
|
||||
return updated;
|
||||
}
|
||||
|
||||
res->set_inputs(nullptr);
|
||||
|
||||
if (lctx->graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
|
||||
LLAMA_LOG_ERROR("%s: failed to compute defrag\n", __func__);
|
||||
return updated;
|
||||
}
|
||||
|
||||
updated = true;
|
||||
}
|
||||
|
||||
return updated;
|
||||
}
|
||||
|
||||
llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch, bool cont) const {
|
||||
llama_kv_cache::slot_info llama_kv_cache::find_slot(const llama_ubatch & ubatch, bool cont) const {
|
||||
|
||||
if (debug > 0) {
|
||||
for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
|
||||
|
|
@ -948,7 +873,7 @@ llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_
|
|||
return res;
|
||||
}
|
||||
|
||||
void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch) {
|
||||
void llama_kv_cache::apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch) {
|
||||
// keep track of the max sequence position that we would overwrite with this ubatch
|
||||
// for non-SWA cache, this would be always empty
|
||||
llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];
|
||||
|
|
@ -1013,21 +938,21 @@ void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_u
|
|||
}
|
||||
}
|
||||
|
||||
bool llama_kv_cache_unified::get_can_shift() const {
|
||||
bool llama_kv_cache::get_can_shift() const {
|
||||
return true;
|
||||
}
|
||||
|
||||
uint32_t llama_kv_cache_unified::get_size() const {
|
||||
uint32_t llama_kv_cache::get_size() const {
|
||||
const auto & cells = v_cells[seq_to_stream[0]];
|
||||
|
||||
return cells.size();
|
||||
}
|
||||
|
||||
uint32_t llama_kv_cache_unified::get_n_stream() const {
|
||||
uint32_t llama_kv_cache::get_n_stream() const {
|
||||
return n_stream;
|
||||
}
|
||||
|
||||
bool llama_kv_cache_unified::get_has_shift() const {
|
||||
bool llama_kv_cache::get_has_shift() const {
|
||||
bool result = false;
|
||||
|
||||
for (uint32_t s = 0; s < n_stream; ++s) {
|
||||
|
|
@ -1037,7 +962,7 @@ bool llama_kv_cache_unified::get_has_shift() const {
|
|||
return result;
|
||||
}
|
||||
|
||||
uint32_t llama_kv_cache_unified::get_n_kv() const {
|
||||
uint32_t llama_kv_cache::get_n_kv() const {
|
||||
uint32_t result = 0;
|
||||
|
||||
for (uint32_t s = 0; s < n_stream; ++s) {
|
||||
|
|
@ -1049,11 +974,11 @@ uint32_t llama_kv_cache_unified::get_n_kv() const {
|
|||
return result;
|
||||
}
|
||||
|
||||
bool llama_kv_cache_unified::get_supports_set_rows() const {
|
||||
bool llama_kv_cache::get_supports_set_rows() const {
|
||||
return supports_set_rows;
|
||||
}
|
||||
|
||||
ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const {
|
||||
ggml_tensor * llama_kv_cache::get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const {
|
||||
const int32_t ikv = map_layer_ids.at(il);
|
||||
|
||||
auto * k = layers[ikv].k;
|
||||
|
|
@ -1073,7 +998,7 @@ ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il, uint
|
|||
ggml_row_size(k->type, n_embd_k_gqa*kv_size)*sinfo.s0);
|
||||
}
|
||||
|
||||
ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const {
|
||||
ggml_tensor * llama_kv_cache::get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const {
|
||||
const int32_t ikv = map_layer_ids.at(il);
|
||||
|
||||
auto * v = layers[ikv].v;
|
||||
|
|
@ -1105,7 +1030,7 @@ ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint
|
|||
ggml_row_size(v->type, kv_size*n_embd_v_gqa)*sinfo.s0);
|
||||
}
|
||||
|
||||
ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const {
|
||||
ggml_tensor * llama_kv_cache::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const {
|
||||
const int32_t ikv = map_layer_ids.at(il);
|
||||
|
||||
auto * k = layers[ikv].k;
|
||||
|
|
@ -1135,7 +1060,7 @@ ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_
|
|||
return ggml_cpy(ctx, k_cur, k_view);
|
||||
}
|
||||
|
||||
ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const {
|
||||
ggml_tensor * llama_kv_cache::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const {
|
||||
const int32_t ikv = map_layer_ids.at(il);
|
||||
|
||||
auto * v = layers[ikv].v;
|
||||
|
|
@ -1189,7 +1114,7 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
|
|||
return ggml_cpy(ctx, v_cur, v_view);
|
||||
}
|
||||
|
||||
ggml_tensor * llama_kv_cache_unified::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
|
||||
ggml_tensor * llama_kv_cache::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
|
||||
const uint32_t n_tokens = ubatch.n_tokens;
|
||||
|
||||
ggml_tensor * k_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
|
||||
|
|
@ -1199,7 +1124,7 @@ ggml_tensor * llama_kv_cache_unified::build_input_k_idxs(ggml_context * ctx, con
|
|||
return k_idxs;
|
||||
}
|
||||
|
||||
ggml_tensor * llama_kv_cache_unified::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
|
||||
ggml_tensor * llama_kv_cache::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
|
||||
const uint32_t n_tokens = ubatch.n_tokens;
|
||||
|
||||
ggml_tensor * v_idxs;
|
||||
|
|
@ -1215,7 +1140,7 @@ ggml_tensor * llama_kv_cache_unified::build_input_v_idxs(ggml_context * ctx, con
|
|||
return v_idxs;
|
||||
}
|
||||
|
||||
void llama_kv_cache_unified::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
|
||||
void llama_kv_cache::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
|
||||
if (!supports_set_rows) {
|
||||
return;
|
||||
}
|
||||
|
|
@ -1235,7 +1160,7 @@ void llama_kv_cache_unified::set_input_k_idxs(ggml_tensor * dst, const llama_uba
|
|||
}
|
||||
}
|
||||
|
||||
void llama_kv_cache_unified::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
|
||||
void llama_kv_cache::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
|
||||
if (!supports_set_rows) {
|
||||
return;
|
||||
}
|
||||
|
|
@ -1272,7 +1197,7 @@ void llama_kv_cache_unified::set_input_v_idxs(ggml_tensor * dst, const llama_uba
|
|||
}
|
||||
}
|
||||
|
||||
void llama_kv_cache_unified::set_input_k_shift(ggml_tensor * dst) const {
|
||||
void llama_kv_cache::set_input_k_shift(ggml_tensor * dst) const {
|
||||
GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
|
||||
|
||||
int32_t * data = (int32_t *) dst->data;
|
||||
|
|
@ -1286,7 +1211,7 @@ void llama_kv_cache_unified::set_input_k_shift(ggml_tensor * dst) const {
|
|||
}
|
||||
}
|
||||
|
||||
void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
|
||||
void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
|
||||
const uint32_t n_tokens = ubatch->n_tokens;
|
||||
|
||||
GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
|
||||
|
|
@ -1358,7 +1283,7 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
|
|||
}
|
||||
}
|
||||
|
||||
void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
|
||||
void llama_kv_cache::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
|
||||
const int64_t n_tokens = ubatch->n_tokens;
|
||||
|
||||
GGML_ASSERT(n_stream == 1 && "TODO: support multiple streams");
|
||||
|
|
@ -1383,7 +1308,7 @@ void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama
|
|||
}
|
||||
}
|
||||
|
||||
size_t llama_kv_cache_unified::total_size() const {
|
||||
size_t llama_kv_cache::total_size() const {
|
||||
size_t size = 0;
|
||||
|
||||
for (const auto & buf : bufs) {
|
||||
|
|
@ -1393,7 +1318,7 @@ size_t llama_kv_cache_unified::total_size() const {
|
|||
return size;
|
||||
}
|
||||
|
||||
size_t llama_kv_cache_unified::size_k_bytes() const {
|
||||
size_t llama_kv_cache::size_k_bytes() const {
|
||||
size_t size_k_bytes = 0;
|
||||
|
||||
for (const auto & layer : layers) {
|
||||
|
|
@ -1403,7 +1328,7 @@ size_t llama_kv_cache_unified::size_k_bytes() const {
|
|||
return size_k_bytes;
|
||||
}
|
||||
|
||||
size_t llama_kv_cache_unified::size_v_bytes() const {
|
||||
size_t llama_kv_cache::size_v_bytes() const {
|
||||
size_t size_v_bytes = 0;
|
||||
|
||||
for (const auto & layer : layers) {
|
||||
|
|
@ -1413,7 +1338,7 @@ size_t llama_kv_cache_unified::size_v_bytes() const {
|
|||
return size_v_bytes;
|
||||
}
|
||||
|
||||
ggml_tensor * llama_kv_cache_unified::build_rope_shift(
|
||||
ggml_tensor * llama_kv_cache::build_rope_shift(
|
||||
const llama_cparams & cparams,
|
||||
ggml_context * ctx,
|
||||
ggml_tensor * cur,
|
||||
|
|
@ -1465,14 +1390,14 @@ ggml_tensor * llama_kv_cache_unified::build_rope_shift(
|
|||
|
||||
class llm_graph_input_k_shift : public llm_graph_input_i {
|
||||
public:
|
||||
llm_graph_input_k_shift(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
|
||||
llm_graph_input_k_shift(const llama_kv_cache * kv_self) : kv_self(kv_self) {}
|
||||
virtual ~llm_graph_input_k_shift() = default;
|
||||
|
||||
void set_input(const llama_ubatch * ubatch) override;
|
||||
|
||||
ggml_tensor * k_shift; // I32 [kv_size*n_stream]
|
||||
|
||||
const llama_kv_cache_unified * kv_self;
|
||||
const llama_kv_cache * kv_self;
|
||||
};
|
||||
|
||||
void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) {
|
||||
|
|
@ -1483,7 +1408,7 @@ void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) {
|
|||
}
|
||||
}
|
||||
|
||||
ggml_cgraph * llama_kv_cache_unified::build_graph_shift(llm_graph_result * res, llama_context * lctx) const {
|
||||
ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_context * lctx) const {
|
||||
auto * ctx = res->get_ctx();
|
||||
auto * gf = res->get_gf();
|
||||
|
||||
|
|
@ -1525,284 +1450,7 @@ ggml_cgraph * llama_kv_cache_unified::build_graph_shift(llm_graph_result * res,
|
|||
return gf;
|
||||
}
|
||||
|
||||
ggml_cgraph * llama_kv_cache_unified::build_graph_defrag(
|
||||
llm_graph_result * res,
|
||||
llama_context * lctx,
|
||||
const defrag_info & dinfo) const {
|
||||
auto * ctx = res->get_ctx();
|
||||
auto * gf = res->get_gf();
|
||||
|
||||
GGML_ASSERT(n_stream == 1 && "n_stream > 1 does not support defrag");
|
||||
|
||||
const auto & cells = v_cells[0];
|
||||
|
||||
const auto & ids = dinfo.ids;
|
||||
|
||||
const auto & cparams = lctx->get_cparams();
|
||||
|
||||
#if 0
|
||||
// CPU defrag
|
||||
//
|
||||
// TODO: optimizations are possible:
|
||||
// - multiple threads
|
||||
// - avoid copying to the host memory when already there
|
||||
//
|
||||
// likely not worth the effort, as we have ggml_graph based defrag
|
||||
//
|
||||
|
||||
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
|
||||
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
|
||||
|
||||
const uint32_t kv_size = size;
|
||||
|
||||
std::vector<uint8_t> buf_k;
|
||||
std::vector<uint8_t> buf_v;
|
||||
|
||||
for (uint32_t il = 0; il < n_layer; ++il) {
|
||||
const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
|
||||
const size_t k_size = ggml_row_size(k_l[il]->type, n_embd_k_gqa*kv_size);
|
||||
|
||||
const size_t v_size_el = ggml_type_size(v_l[il]->type);
|
||||
const size_t v_size = ggml_row_size (v_l[il]->type, n_embd_v_gqa*kv_size);
|
||||
|
||||
buf_k.resize(k_size);
|
||||
buf_v.resize(v_size);
|
||||
|
||||
ggml_backend_tensor_get(k_l[il], buf_k.data(), 0, buf_k.size());
|
||||
ggml_backend_tensor_get(v_l[il], buf_v.data(), 0, buf_v.size());
|
||||
|
||||
// batch move [i, i+nm) to [id, id+nm)
|
||||
// note: cells can move only to a lower index
|
||||
for (uint32_t i = 0; i < n_kv; ++i) {
|
||||
const uint32_t id = ids[i];
|
||||
|
||||
if (i == id || id == n_kv) {
|
||||
continue;
|
||||
}
|
||||
|
||||
uint32_t nm = 1;
|
||||
|
||||
while (i + nm < n_kv && ids[i + nm] == id + nm) {
|
||||
nm++;
|
||||
}
|
||||
|
||||
// move keys
|
||||
{
|
||||
const int64_t os = i*k_size_row;
|
||||
const int64_t od = id*k_size_row;
|
||||
|
||||
memcpy(buf_k.data() + od, buf_k.data() + os, nm*k_size_row);
|
||||
}
|
||||
|
||||
// move values (note: they are transposed)
|
||||
{
|
||||
const int64_t os = i;
|
||||
const int64_t od = id;
|
||||
|
||||
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
|
||||
memcpy(buf_v.data() + (od + j*kv_size)*v_size_el, buf_v.data() + (os + j*kv_size)*v_size_el, nm*v_size_el);
|
||||
}
|
||||
}
|
||||
|
||||
i += nm - 1;
|
||||
}
|
||||
|
||||
ggml_backend_tensor_set(k_l[il], buf_k.data(), 0, buf_k.size());
|
||||
ggml_backend_tensor_set(v_l[il], buf_v.data(), 0, buf_v.size());
|
||||
}
|
||||
#else
|
||||
for (uint32_t i = 0; i < ids.size(); ++i) {
|
||||
const uint32_t id = ids[i];
|
||||
|
||||
if (i == id || id == ids.size()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
uint32_t nm = 1;
|
||||
|
||||
while (i + nm < ids.size() && ids[i + nm] == id + nm) {
|
||||
nm++;
|
||||
}
|
||||
|
||||
for (const auto & layer : layers) {
|
||||
const uint32_t il = layer.il;
|
||||
|
||||
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
|
||||
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
|
||||
|
||||
ggml_tensor * view_k_src = ggml_view_2d(ctx, layer.k,
|
||||
n_embd_k_gqa, nm,
|
||||
ggml_row_size(layer.k->type, n_embd_k_gqa),
|
||||
ggml_row_size(layer.k->type, n_embd_k_gqa*i));
|
||||
|
||||
ggml_tensor * view_k_dst = ggml_view_2d(ctx, layer.k,
|
||||
n_embd_k_gqa, nm,
|
||||
ggml_row_size(layer.k->type, n_embd_k_gqa),
|
||||
ggml_row_size(layer.k->type, n_embd_k_gqa*id));
|
||||
|
||||
ggml_tensor * view_v_src;
|
||||
ggml_tensor * view_v_dst;
|
||||
|
||||
if (cparams.flash_attn) {
|
||||
// NOTE: the V cache is not transposed when using flash attention
|
||||
view_v_src = ggml_view_2d(ctx, layer.v,
|
||||
n_embd_v_gqa, nm,
|
||||
ggml_row_size(layer.v->type, n_embd_v_gqa),
|
||||
ggml_row_size(layer.v->type, n_embd_v_gqa*i));
|
||||
|
||||
view_v_dst = ggml_view_2d(ctx, layer.v,
|
||||
n_embd_v_gqa, nm,
|
||||
ggml_row_size(layer.v->type, n_embd_v_gqa),
|
||||
ggml_row_size(layer.v->type, n_embd_v_gqa*id));
|
||||
} else {
|
||||
view_v_src = ggml_view_2d(ctx, layer.v,
|
||||
nm, n_embd_v_gqa,
|
||||
ggml_row_size(layer.v->type, cells.size()),
|
||||
ggml_row_size(layer.v->type, i));
|
||||
|
||||
view_v_dst = ggml_view_2d(ctx, layer.v,
|
||||
nm, n_embd_v_gqa,
|
||||
ggml_row_size(layer.v->type, cells.size()),
|
||||
ggml_row_size(layer.v->type, id));
|
||||
}
|
||||
|
||||
ggml_build_forward_expand(gf, ggml_cpy(ctx, view_k_src, view_k_dst));
|
||||
ggml_build_forward_expand(gf, ggml_cpy(ctx, view_v_src, view_v_dst));
|
||||
}
|
||||
|
||||
i += nm - 1;
|
||||
}
|
||||
|
||||
//LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes);
|
||||
#endif
|
||||
|
||||
return gf;
|
||||
}
|
||||
|
||||
llama_kv_cache_unified::defrag_info llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) const {
|
||||
GGML_ASSERT(n_stream == 1 && "n_stream > 1 does not support defrag");
|
||||
|
||||
const auto & cells = v_cells[0];
|
||||
|
||||
const uint32_t n_layer = layers.size();
|
||||
|
||||
const uint32_t n_kv = cells.used_max_p1();
|
||||
const uint32_t n_used = cells.get_used();
|
||||
|
||||
assert(n_used <= n_kv);
|
||||
|
||||
//const int64_t t_start = ggml_time_us();
|
||||
|
||||
// number of cells moved
|
||||
uint32_t n_moves = 0;
|
||||
|
||||
// each move requires 6*n_layer tensors (see graph_build_kv_self_defrag)
|
||||
// - source view, destination view, copy operation
|
||||
// - x2 for keys and values
|
||||
//const uint32_t max_moves = max_nodes()/(6*n_layer);
|
||||
// TODO: tmp fix https://github.com/ggerganov/llama.cpp/issues/6685#issuecomment-2057579516
|
||||
const uint32_t max_moves = (n_max_nodes - 2*n_layer)/(6*n_layer);
|
||||
|
||||
// determine which KV cells to move where
|
||||
defrag_info res;
|
||||
auto & ids = res.ids;
|
||||
|
||||
ids.resize(n_kv, n_kv);
|
||||
|
||||
for (uint32_t i0 = 0; i0 < n_used; ++i0) {
|
||||
if (!cells.is_empty(i0)) {
|
||||
ids[i0] = i0;
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
// found a hole - fill it with data from the end of the cache
|
||||
|
||||
uint32_t nh = 1;
|
||||
|
||||
// determine the size of the hole
|
||||
while (i0 + nh < n_used && cells.is_empty(i0 + nh)) {
|
||||
nh++;
|
||||
}
|
||||
|
||||
uint32_t nf = 0;
|
||||
uint32_t is = n_kv - 1;
|
||||
|
||||
// starting from the end, find nh non-empty cells
|
||||
for (; is > i0; --is) {
|
||||
if (cells.is_empty(is) || ids[is] != n_kv) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// non-empty cell which is not yet moved
|
||||
nf++;
|
||||
|
||||
if (nf == nh) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// this can only happen if `n_used` is not accurate, which would be a bug
|
||||
GGML_ASSERT(nf == nh && "KV defrag bug: nf != nh");
|
||||
|
||||
nf = 0;
|
||||
|
||||
uint32_t i1 = is;
|
||||
|
||||
// are we moving a continuous block of memory?
|
||||
bool cont = false;
|
||||
|
||||
// should we stop searching for the next move?
|
||||
bool stop = false;
|
||||
|
||||
// go back and move the nf cells to the hole
|
||||
for (; i1 < n_kv; ++i1) {
|
||||
if (cells.is_empty(i1) || ids[i1] != n_kv) {
|
||||
if (n_moves == max_moves) {
|
||||
stop = true;
|
||||
break;
|
||||
}
|
||||
|
||||
cont = false;
|
||||
continue;
|
||||
}
|
||||
|
||||
// this cell goes to (i0 + nf)
|
||||
ids[i1] = i0 + nf;
|
||||
|
||||
if (!cont) {
|
||||
n_moves++;
|
||||
cont = true;
|
||||
}
|
||||
|
||||
nf++;
|
||||
|
||||
if (nf == nh) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (stop || n_moves == max_moves) {
|
||||
break;
|
||||
}
|
||||
|
||||
//LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, i1 + 1, i0, i0 + nh);
|
||||
|
||||
i0 += nh - 1;
|
||||
}
|
||||
|
||||
if (n_moves == 0) {
|
||||
return {};
|
||||
}
|
||||
|
||||
LLAMA_LOG_DEBUG("%s: (tmp log) KV defrag cell moves: %u\n", __func__, n_moves);
|
||||
|
||||
LLAMA_LOG_DEBUG("%s: expected gf nodes: %u\n", __func__, 6*n_moves*n_layer);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
|
||||
bool llama_kv_cache::is_masked_swa(llama_pos p0, llama_pos p1) const {
|
||||
assert(p0 >= 0 && p1 >= 0);
|
||||
|
||||
switch (swa_type) {
|
||||
|
|
@ -1828,7 +1476,7 @@ bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
|
|||
return false;
|
||||
}
|
||||
|
||||
void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
|
||||
void llama_kv_cache::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
|
||||
GGML_UNUSED(flags);
|
||||
|
||||
io.write(&n_stream, sizeof(n_stream));
|
||||
|
|
@ -1881,7 +1529,7 @@ void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq
|
|||
}
|
||||
}
|
||||
|
||||
void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
|
||||
void llama_kv_cache::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
|
||||
GGML_UNUSED(flags);
|
||||
|
||||
GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()));
|
||||
|
|
@ -1917,7 +1565,7 @@ void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_i
|
|||
}
|
||||
}
|
||||
|
||||
void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const cell_ranges_t & cr, llama_seq_id seq_id) const {
|
||||
void llama_kv_cache::state_write_meta(llama_io_write_i & io, const cell_ranges_t & cr, llama_seq_id seq_id) const {
|
||||
const auto & cells = v_cells[cr.strm];
|
||||
|
||||
for (const auto & range : cr.data) {
|
||||
|
|
@ -1945,7 +1593,7 @@ void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const cell_
|
|||
}
|
||||
}
|
||||
|
||||
void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const cell_ranges_t & cr) const {
|
||||
void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t & cr) const {
|
||||
const auto & cells = v_cells[cr.strm];
|
||||
|
||||
const uint32_t v_trans = this->v_trans ? 1 : 0;
|
||||
|
|
@ -2040,7 +1688,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const cell_
|
|||
}
|
||||
}
|
||||
|
||||
bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, llama_seq_id dest_seq_id) {
|
||||
bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, llama_seq_id dest_seq_id) {
|
||||
auto & cells = v_cells[strm];
|
||||
auto & head = v_heads[strm];
|
||||
|
||||
|
|
@ -2137,7 +1785,7 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t strm
|
|||
return true;
|
||||
}
|
||||
|
||||
bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count) {
|
||||
bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count) {
|
||||
auto & cells = v_cells[strm];
|
||||
auto & head = v_heads[strm];
|
||||
|
||||
|
|
@ -2274,13 +1922,13 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t strm
|
|||
}
|
||||
|
||||
//
|
||||
// llama_kv_cache_unified_context
|
||||
// llama_kv_cache_context
|
||||
//
|
||||
|
||||
llama_kv_cache_unified_context::llama_kv_cache_unified_context(llama_memory_status status) : status(status) {}
|
||||
llama_kv_cache_context::llama_kv_cache_context(llama_memory_status status) : status(status) {}
|
||||
|
||||
llama_kv_cache_unified_context::llama_kv_cache_unified_context(
|
||||
llama_kv_cache_unified * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) {
|
||||
llama_kv_cache_context::llama_kv_cache_context(
|
||||
llama_kv_cache * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) {
|
||||
n_kv = kv->get_size();
|
||||
|
||||
const uint32_t n_stream = kv->get_n_stream();
|
||||
|
|
@ -2296,26 +1944,25 @@ llama_kv_cache_unified_context::llama_kv_cache_unified_context(
|
|||
}
|
||||
}
|
||||
|
||||
llama_kv_cache_unified_context::llama_kv_cache_unified_context(
|
||||
llama_kv_cache_unified * kv,
|
||||
llama_kv_cache_context::llama_kv_cache_context(
|
||||
llama_kv_cache * kv,
|
||||
llama_context * lctx,
|
||||
bool do_shift,
|
||||
defrag_info dinfo,
|
||||
stream_copy_info sc_info) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), lctx(lctx), do_shift(do_shift), dinfo(std::move(dinfo)), sc_info(std::move(sc_info)) {
|
||||
if (!do_shift && this->dinfo.empty() && this->sc_info.empty()) {
|
||||
stream_copy_info sc_info) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), lctx(lctx), do_shift(do_shift), sc_info(std::move(sc_info)) {
|
||||
if (!do_shift && this->sc_info.empty()) {
|
||||
status = LLAMA_MEMORY_STATUS_NO_UPDATE;
|
||||
}
|
||||
}
|
||||
|
||||
llama_kv_cache_unified_context::llama_kv_cache_unified_context(
|
||||
llama_kv_cache_unified * kv,
|
||||
llama_kv_cache_unified::slot_info_vec_t sinfos,
|
||||
llama_kv_cache_context::llama_kv_cache_context(
|
||||
llama_kv_cache * kv,
|
||||
llama_kv_cache::slot_info_vec_t sinfos,
|
||||
std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sinfos(std::move(sinfos)), ubatches(std::move(ubatches)) {
|
||||
}
|
||||
|
||||
llama_kv_cache_unified_context::~llama_kv_cache_unified_context() = default;
|
||||
llama_kv_cache_context::~llama_kv_cache_context() = default;
|
||||
|
||||
bool llama_kv_cache_unified_context::next() {
|
||||
bool llama_kv_cache_context::next() {
|
||||
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
||||
|
||||
if (++i_cur >= ubatches.size()) {
|
||||
|
|
@ -2325,12 +1972,12 @@ bool llama_kv_cache_unified_context::next() {
|
|||
return true;
|
||||
}
|
||||
|
||||
bool llama_kv_cache_unified_context::apply() {
|
||||
bool llama_kv_cache_context::apply() {
|
||||
assert(!llama_memory_status_is_fail(status));
|
||||
|
||||
// no ubatches -> this is a KV cache update
|
||||
if (ubatches.empty()) {
|
||||
kv->update(lctx, do_shift, dinfo, sc_info);
|
||||
kv->update(lctx, do_shift, sc_info);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
|
@ -2342,69 +1989,69 @@ bool llama_kv_cache_unified_context::apply() {
|
|||
return true;
|
||||
}
|
||||
|
||||
llama_memory_status llama_kv_cache_unified_context::get_status() const {
|
||||
llama_memory_status llama_kv_cache_context::get_status() const {
|
||||
return status;
|
||||
}
|
||||
|
||||
const llama_ubatch & llama_kv_cache_unified_context::get_ubatch() const {
|
||||
const llama_ubatch & llama_kv_cache_context::get_ubatch() const {
|
||||
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
||||
|
||||
return ubatches[i_cur];
|
||||
}
|
||||
|
||||
uint32_t llama_kv_cache_unified_context::get_n_kv() const {
|
||||
uint32_t llama_kv_cache_context::get_n_kv() const {
|
||||
return n_kv;
|
||||
}
|
||||
|
||||
bool llama_kv_cache_unified_context::get_supports_set_rows() const {
|
||||
bool llama_kv_cache_context::get_supports_set_rows() const {
|
||||
return kv->get_supports_set_rows();
|
||||
}
|
||||
|
||||
ggml_tensor * llama_kv_cache_unified_context::get_k(ggml_context * ctx, int32_t il) const {
|
||||
ggml_tensor * llama_kv_cache_context::get_k(ggml_context * ctx, int32_t il) const {
|
||||
return kv->get_k(ctx, il, n_kv, sinfos[i_cur]);
|
||||
}
|
||||
|
||||
ggml_tensor * llama_kv_cache_unified_context::get_v(ggml_context * ctx, int32_t il) const {
|
||||
ggml_tensor * llama_kv_cache_context::get_v(ggml_context * ctx, int32_t il) const {
|
||||
return kv->get_v(ctx, il, n_kv, sinfos[i_cur]);
|
||||
}
|
||||
|
||||
ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const {
|
||||
ggml_tensor * llama_kv_cache_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const {
|
||||
return kv->cpy_k(ctx, k_cur, k_idxs, il, sinfos[i_cur]);
|
||||
}
|
||||
|
||||
ggml_tensor * llama_kv_cache_unified_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const {
|
||||
ggml_tensor * llama_kv_cache_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const {
|
||||
return kv->cpy_v(ctx, v_cur, v_idxs, il, sinfos[i_cur]);
|
||||
}
|
||||
|
||||
ggml_tensor * llama_kv_cache_unified_context::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
|
||||
ggml_tensor * llama_kv_cache_context::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
|
||||
return kv->build_input_k_idxs(ctx, ubatch);
|
||||
}
|
||||
|
||||
ggml_tensor * llama_kv_cache_unified_context::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
|
||||
ggml_tensor * llama_kv_cache_context::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
|
||||
return kv->build_input_v_idxs(ctx, ubatch);
|
||||
}
|
||||
|
||||
void llama_kv_cache_unified_context::set_input_k_shift(ggml_tensor * dst) const {
|
||||
void llama_kv_cache_context::set_input_k_shift(ggml_tensor * dst) const {
|
||||
kv->set_input_k_shift(dst);
|
||||
}
|
||||
|
||||
void llama_kv_cache_unified_context::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
|
||||
void llama_kv_cache_context::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
|
||||
kv->set_input_k_idxs(dst, ubatch, sinfos[i_cur]);
|
||||
}
|
||||
|
||||
void llama_kv_cache_unified_context::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
|
||||
void llama_kv_cache_context::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
|
||||
kv->set_input_v_idxs(dst, ubatch, sinfos[i_cur]);
|
||||
}
|
||||
|
||||
void llama_kv_cache_unified_context::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
|
||||
void llama_kv_cache_context::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
|
||||
kv->set_input_kq_mask(dst, ubatch, causal_attn);
|
||||
}
|
||||
|
||||
void llama_kv_cache_unified_context::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
|
||||
void llama_kv_cache_context::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
|
||||
kv->set_input_pos_bucket(dst, ubatch);
|
||||
}
|
||||
|
||||
uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) {
|
||||
uint32_t llama_kv_cache::get_padding(const llama_cparams & cparams) {
|
||||
// the FA kernels require padding to avoid extra runtime boundary checks
|
||||
return cparams.flash_attn ? 256u : 32u;
|
||||
}
|
||||
|
|
@ -14,27 +14,16 @@ struct llama_model;
|
|||
struct llama_context;
|
||||
|
||||
//
|
||||
// llama_kv_cache_unified
|
||||
// llama_kv_cache
|
||||
//
|
||||
|
||||
class llama_kv_cache_unified : public llama_memory_i {
|
||||
class llama_kv_cache : public llama_memory_i {
|
||||
public:
|
||||
static uint32_t get_padding(const llama_cparams & cparams);
|
||||
|
||||
// this callback is used to filter out layers that should not be included in the cache
|
||||
using layer_filter_cb = std::function<bool(int32_t il)>;
|
||||
|
||||
struct defrag_info {
|
||||
bool empty() const {
|
||||
return ids.empty();
|
||||
}
|
||||
|
||||
// contains information about which cell moves where:
|
||||
// - cell i moves to ids[i]
|
||||
// - if ids[i] == i || ids[i] == ids.size(), then cell i is not moved
|
||||
std::vector<uint32_t> ids;
|
||||
};
|
||||
|
||||
struct stream_copy_info {
|
||||
bool empty() const {
|
||||
assert(ssrc.size() == sdst.size());
|
||||
|
|
@ -92,7 +81,7 @@ public:
|
|||
|
||||
using slot_info_vec_t = std::vector<slot_info>;
|
||||
|
||||
llama_kv_cache_unified(
|
||||
llama_kv_cache(
|
||||
const llama_model & model,
|
||||
layer_filter_cb && filter,
|
||||
ggml_type type_k,
|
||||
|
|
@ -106,7 +95,7 @@ public:
|
|||
uint32_t n_swa,
|
||||
llama_swa_type swa_type);
|
||||
|
||||
~llama_kv_cache_unified() = default;
|
||||
~llama_kv_cache() = default;
|
||||
|
||||
//
|
||||
// llama_memory_i
|
||||
|
|
@ -140,7 +129,7 @@ public:
|
|||
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;
|
||||
|
||||
//
|
||||
// llama_kv_cache_unified specific API
|
||||
// llama_kv_cache specific API
|
||||
//
|
||||
|
||||
uint32_t get_size() const;
|
||||
|
|
@ -173,7 +162,7 @@ public:
|
|||
// return empty vector on failure
|
||||
slot_info_vec_t prepare(const std::vector<llama_ubatch> & ubatches);
|
||||
|
||||
bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo, const stream_copy_info & sc_info);
|
||||
bool update(llama_context * lctx, bool do_shift, const stream_copy_info & sc_info);
|
||||
|
||||
// find a slot of kv cells that can hold the ubatch
|
||||
// if cont == true, then the slot must be continuous
|
||||
|
|
@ -241,7 +230,7 @@ private:
|
|||
// note: this is not part of the KV state and it's only used to speed-up the find_slot() method
|
||||
std::vector<uint32_t> v_heads;
|
||||
|
||||
std::vector<llama_kv_cells_unified> v_cells;
|
||||
std::vector<llama_kv_cells> v_cells;
|
||||
|
||||
// maps from a sequence id to a stream id
|
||||
std::vector<uint32_t> seq_to_stream;
|
||||
|
|
@ -254,9 +243,6 @@ private:
|
|||
// model layer id -> KV cache layer id
|
||||
std::unordered_map<int32_t, int32_t> map_layer_ids;
|
||||
|
||||
// return non-empty vector if cells have been moved
|
||||
defrag_info defrag_prepare(int32_t n_max_nodes) const;
|
||||
|
||||
size_t total_size() const;
|
||||
|
||||
size_t size_k_bytes() const;
|
||||
|
|
@ -277,11 +263,6 @@ private:
|
|||
llm_graph_result * res,
|
||||
llama_context * lctx) const;
|
||||
|
||||
ggml_cgraph * build_graph_defrag(
|
||||
llm_graph_result * res,
|
||||
llama_context * lctx,
|
||||
const defrag_info & dinfo) const;
|
||||
|
||||
struct cell_ranges_t {
|
||||
uint32_t strm;
|
||||
|
||||
|
|
@ -295,35 +276,33 @@ private:
|
|||
bool state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count);
|
||||
};
|
||||
|
||||
class llama_kv_cache_unified_context : public llama_memory_context_i {
|
||||
class llama_kv_cache_context : public llama_memory_context_i {
|
||||
public:
|
||||
// some shorthands
|
||||
using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
|
||||
using defrag_info = llama_kv_cache_unified::defrag_info;
|
||||
using stream_copy_info = llama_kv_cache_unified::stream_copy_info;
|
||||
using slot_info_vec_t = llama_kv_cache::slot_info_vec_t;
|
||||
using stream_copy_info = llama_kv_cache::stream_copy_info;
|
||||
|
||||
// used for errors
|
||||
llama_kv_cache_unified_context(llama_memory_status status);
|
||||
llama_kv_cache_context(llama_memory_status status);
|
||||
|
||||
// used to create a full-cache context
|
||||
llama_kv_cache_unified_context(
|
||||
llama_kv_cache_unified * kv);
|
||||
llama_kv_cache_context(
|
||||
llama_kv_cache * kv);
|
||||
|
||||
// used to create an update context
|
||||
llama_kv_cache_unified_context(
|
||||
llama_kv_cache_unified * kv,
|
||||
llama_kv_cache_context(
|
||||
llama_kv_cache * kv,
|
||||
llama_context * lctx,
|
||||
bool do_shift,
|
||||
defrag_info dinfo,
|
||||
stream_copy_info sc_info);
|
||||
|
||||
// used to create a batch procesing context from a batch
|
||||
llama_kv_cache_unified_context(
|
||||
llama_kv_cache_unified * kv,
|
||||
llama_kv_cache_context(
|
||||
llama_kv_cache * kv,
|
||||
slot_info_vec_t sinfos,
|
||||
std::vector<llama_ubatch> ubatches);
|
||||
|
||||
virtual ~llama_kv_cache_unified_context();
|
||||
virtual ~llama_kv_cache_context();
|
||||
|
||||
//
|
||||
// llama_memory_context_i
|
||||
|
|
@ -336,7 +315,7 @@ public:
|
|||
const llama_ubatch & get_ubatch() const override;
|
||||
|
||||
//
|
||||
// llama_kv_cache_unified_context specific API
|
||||
// llama_kv_cache_context specific API
|
||||
//
|
||||
|
||||
uint32_t get_n_kv() const;
|
||||
|
|
@ -365,7 +344,7 @@ public:
|
|||
private:
|
||||
llama_memory_status status;
|
||||
|
||||
llama_kv_cache_unified * kv;
|
||||
llama_kv_cache * kv;
|
||||
llama_context * lctx;
|
||||
|
||||
//
|
||||
|
|
@ -374,8 +353,6 @@ private:
|
|||
|
||||
bool do_shift = false;
|
||||
|
||||
defrag_info dinfo;
|
||||
|
||||
stream_copy_info sc_info;
|
||||
|
||||
//
|
||||
|
|
@ -11,7 +11,7 @@
|
|||
|
||||
// meta information about KV cells that can be part of multiple sequences at the same time
|
||||
// TODO: add unit tests
|
||||
class llama_kv_cells_unified {
|
||||
class llama_kv_cells {
|
||||
public:
|
||||
void reset() {
|
||||
for (uint32_t i = 0; i < pos.size(); ++i) {
|
||||
|
|
@ -77,30 +77,30 @@ public:
|
|||
}
|
||||
|
||||
// move cell isrc to idst (used during defrag)
|
||||
void mv(uint32_t isrc, uint32_t idst) {
|
||||
assert(isrc < pos.size());
|
||||
assert(idst < pos.size());
|
||||
//void mv(uint32_t isrc, uint32_t idst) {
|
||||
// assert(isrc < pos.size());
|
||||
// assert(idst < pos.size());
|
||||
|
||||
assert(pos[idst] == -1);
|
||||
assert(pos[isrc] != -1);
|
||||
// assert(pos[idst] == -1);
|
||||
// assert(pos[isrc] != -1);
|
||||
|
||||
pos [idst] = pos [isrc];
|
||||
shift[idst] = shift[isrc];
|
||||
seq [idst] = seq [isrc];
|
||||
// pos [idst] = pos [isrc];
|
||||
// shift[idst] = shift[isrc];
|
||||
// seq [idst] = seq [isrc];
|
||||
|
||||
pos [isrc] = -1;
|
||||
shift[isrc] = 0;
|
||||
seq [isrc].reset();
|
||||
// pos [isrc] = -1;
|
||||
// shift[isrc] = 0;
|
||||
// seq [isrc].reset();
|
||||
|
||||
used.erase (isrc);
|
||||
used.insert(idst);
|
||||
}
|
||||
// used.erase (isrc);
|
||||
// used.insert(idst);
|
||||
//}
|
||||
|
||||
// copy the state of cells [i, i + n) (used for save/restore the state of the cells)
|
||||
llama_kv_cells_unified cp(uint32_t i, uint32_t n) const {
|
||||
llama_kv_cells cp(uint32_t i, uint32_t n) const {
|
||||
assert(i + n <= pos.size());
|
||||
|
||||
llama_kv_cells_unified res;
|
||||
llama_kv_cells res;
|
||||
|
||||
res.resize(n);
|
||||
|
||||
|
|
@ -117,8 +117,8 @@ public:
|
|||
}
|
||||
|
||||
// copy the state of cells [idxs[0], idxs[1], ..., idxs[idxs.size() - 1])
|
||||
llama_kv_cells_unified cp(const std::vector<uint32_t> & idxs) const {
|
||||
llama_kv_cells_unified res;
|
||||
llama_kv_cells cp(const std::vector<uint32_t> & idxs) const {
|
||||
llama_kv_cells res;
|
||||
|
||||
res.resize(idxs.size());
|
||||
|
||||
|
|
@ -135,7 +135,7 @@ public:
|
|||
}
|
||||
|
||||
// set the state of cells [i, i + other.pos.size()) (used for save/restore the state of the cells)
|
||||
void set(uint32_t i, const llama_kv_cells_unified & other) {
|
||||
void set(uint32_t i, const llama_kv_cells & other) {
|
||||
assert(i + other.pos.size() <= pos.size());
|
||||
|
||||
for (uint32_t j = 0; j < other.pos.size(); ++j) {
|
||||
|
|
@ -165,7 +165,7 @@ public:
|
|||
}
|
||||
|
||||
// set the state of cells [idxs[0], idxs[1], ..., idxs[idxs.size() - 1])
|
||||
void set(const std::vector<uint32_t> & idxs, const llama_kv_cells_unified & other) {
|
||||
void set(const std::vector<uint32_t> & idxs, const llama_kv_cells & other) {
|
||||
assert(idxs.size() == other.pos.size());
|
||||
|
||||
for (uint32_t j = 0; j < other.pos.size(); ++j) {
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ llama_memory_hybrid::llama_memory_hybrid(
|
|||
layer_filter_cb && filter_attn,
|
||||
layer_filter_cb && filter_recr) :
|
||||
hparams(model.hparams),
|
||||
mem_attn(new llama_kv_cache_unified(
|
||||
mem_attn(new llama_kv_cache(
|
||||
model,
|
||||
filter_attn == nullptr ?
|
||||
[&](int32_t il) { return !hparams.is_recurrent(il); }
|
||||
|
|
@ -179,7 +179,7 @@ void llama_memory_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id,
|
|||
mem_recr->state_read(io, seq_id);
|
||||
}
|
||||
|
||||
llama_kv_cache_unified * llama_memory_hybrid::get_mem_attn() const {
|
||||
llama_kv_cache * llama_memory_hybrid::get_mem_attn() const {
|
||||
return mem_attn.get();
|
||||
}
|
||||
|
||||
|
|
@ -210,7 +210,7 @@ llama_memory_hybrid_context::llama_memory_hybrid_context(
|
|||
std::vector<llama_ubatch> ubatches) :
|
||||
ubatches(std::move(ubatches)),
|
||||
// note: here we copy the ubatches. not sure if this is ideal
|
||||
ctx_attn(new llama_kv_cache_unified_context(mem->get_mem_attn(), std::move(sinfos_attn), this->ubatches)),
|
||||
ctx_attn(new llama_kv_cache_context(mem->get_mem_attn(), std::move(sinfos_attn), this->ubatches)),
|
||||
ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)),
|
||||
status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
|
||||
}
|
||||
|
|
@ -248,8 +248,8 @@ const llama_ubatch & llama_memory_hybrid_context::get_ubatch() const {
|
|||
return ubatches[i_next];
|
||||
}
|
||||
|
||||
const llama_kv_cache_unified_context * llama_memory_hybrid_context::get_attn() const {
|
||||
return static_cast<const llama_kv_cache_unified_context *>(ctx_attn.get());
|
||||
const llama_kv_cache_context * llama_memory_hybrid_context::get_attn() const {
|
||||
return static_cast<const llama_kv_cache_context *>(ctx_attn.get());
|
||||
}
|
||||
|
||||
const llama_memory_recurrent_context * llama_memory_hybrid_context::get_recr() const {
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
#include "llama-batch.h"
|
||||
#include "llama-graph.h"
|
||||
#include "llama-kv-cache-unified.h"
|
||||
#include "llama-kv-cache.h"
|
||||
#include "llama-memory.h"
|
||||
#include "llama-memory-recurrent.h"
|
||||
|
||||
|
|
@ -13,7 +13,7 @@
|
|||
// llama_memory_hybrid
|
||||
//
|
||||
|
||||
// utilizes instances of llama_memory_recurrent and llama_kv_cache_unified to
|
||||
// utilizes instances of llama_memory_recurrent and llama_kv_cache to
|
||||
// support models where each layer may be either attention-based or recurrent
|
||||
|
||||
class llama_memory_hybrid : public llama_memory_i {
|
||||
|
|
@ -81,19 +81,19 @@ public:
|
|||
// llama_memory_hybrid specific API
|
||||
//
|
||||
|
||||
llama_kv_cache_unified * get_mem_attn() const;
|
||||
llama_kv_cache * get_mem_attn() const;
|
||||
llama_memory_recurrent * get_mem_recr() const;
|
||||
|
||||
private:
|
||||
const llama_hparams & hparams;
|
||||
|
||||
const std::unique_ptr<llama_kv_cache_unified> mem_attn;
|
||||
const std::unique_ptr<llama_kv_cache> mem_attn;
|
||||
const std::unique_ptr<llama_memory_recurrent> mem_recr;
|
||||
};
|
||||
|
||||
class llama_memory_hybrid_context : public llama_memory_context_i {
|
||||
public:
|
||||
using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
|
||||
using slot_info_vec_t = llama_kv_cache::slot_info_vec_t;
|
||||
|
||||
// init failure
|
||||
explicit llama_memory_hybrid_context(llama_memory_status status);
|
||||
|
|
@ -125,7 +125,7 @@ public:
|
|||
// llama_memory_hybrid_context
|
||||
//
|
||||
|
||||
const llama_kv_cache_unified_context * get_attn() const;
|
||||
const llama_kv_cache_context * get_attn() const;
|
||||
const llama_memory_recurrent_context * get_recr() const;
|
||||
|
||||
private:
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@
|
|||
//
|
||||
|
||||
// TODO: extract the cache state used for graph computation into llama_memory_recurrent_context_i
|
||||
// see the implementation of llama_kv_cache_unified_context_i for an example how to do it
|
||||
// see the implementation of llama_kv_cache_context_i for an example how to do it
|
||||
class llama_memory_recurrent : public llama_memory_i {
|
||||
public:
|
||||
|
||||
|
|
|
|||
|
|
@ -36,8 +36,8 @@ bool llama_memory_status_is_fail(llama_memory_status status);
|
|||
|
||||
// the interface for managing the memory context during batch processing
|
||||
// this interface is implemented per memory type. see:
|
||||
// - llama_kv_cache_unified_context
|
||||
// - llama_kv_cache_unified_iswa_context
|
||||
// - llama_kv_cache_context
|
||||
// - llama_kv_cache_iswa_context
|
||||
// ...
|
||||
//
|
||||
// the only method that should mutate the memory and the memory context is llama_memory_i::apply()
|
||||
|
|
@ -77,7 +77,7 @@ struct llama_memory_i {
|
|||
// simulate full cache, used for allocating worst-case compute buffers
|
||||
virtual llama_memory_context_ptr init_full() = 0;
|
||||
|
||||
// prepare for any pending memory updates, such as shifts, defrags, etc.
|
||||
// prepare for any pending memory updates, such as shifts, copies, etc.
|
||||
// status == LLAMA_MEMORY_STATUS_NO_UPDATE if there is nothing to update
|
||||
virtual llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) = 0;
|
||||
|
||||
|
|
@ -109,8 +109,3 @@ struct llama_memory_i {
|
|||
};
|
||||
|
||||
using llama_memory_ptr = std::unique_ptr<llama_memory_i>;
|
||||
|
||||
// TODO: temporary until the llama_kv_cache is removed from the public API
|
||||
struct llama_kv_cache : public llama_memory_i {
|
||||
virtual ~llama_kv_cache() = default;
|
||||
};
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -11,8 +11,8 @@ static bool old_mixtral_warning_showed = false;
|
|||
#include "llama-vocab.cpp"
|
||||
#include "llama-grammar.cpp"
|
||||
#include "llama-sampling.cpp"
|
||||
#include "llama-kv-cache-unified.cpp"
|
||||
#include "llama-kv-cache-unified-iswa.cpp"
|
||||
#include "llama-kv-cache.cpp"
|
||||
#include "llama-kv-cache-iswa.cpp"
|
||||
#include "llama-memory-hybrid.cpp"
|
||||
#include "llama-memory-recurrent.cpp"
|
||||
#include "llama-model-loader.cpp"
|
||||
|
|
|
|||
|
|
@ -3689,7 +3689,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
|
|||
const int height = img->ny;
|
||||
const int total_factor = params.patch_size * params.proj_scale_factor;
|
||||
constexpr int min_image_tokens = 64;
|
||||
constexpr int max_image_tokens = 256;
|
||||
constexpr int max_image_tokens = 1024;
|
||||
const float min_pixels = min_image_tokens * total_factor * total_factor;
|
||||
const float max_pixels = max_image_tokens * total_factor * total_factor;
|
||||
|
||||
|
|
|
|||
|
|
@ -274,7 +274,6 @@ def start_server_background(args):
|
|||
server_args.extend(['--batch-size', args.batch_size])
|
||||
server_args.extend(['--ubatch-size', args.ubatch_size])
|
||||
server_args.extend(['--n-predict', args.max_tokens * 2])
|
||||
server_args.extend(['--defrag-thold', "0.1"])
|
||||
server_args.append('--cont-batching')
|
||||
server_args.append('--metrics')
|
||||
server_args.append('--flash-attn')
|
||||
|
|
|
|||
|
|
@ -4309,6 +4309,7 @@ int main(int argc, char ** argv) {
|
|||
};
|
||||
|
||||
const auto handle_api_show = [&ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
|
||||
bool has_mtmd = ctx_server.mctx != nullptr;
|
||||
json data = {
|
||||
{
|
||||
"template", common_chat_templates_source(ctx_server.chat_templates.get()),
|
||||
|
|
@ -4330,7 +4331,7 @@ int main(int argc, char ** argv) {
|
|||
{"quantization_level", ""}
|
||||
}},
|
||||
{"model_info", ""},
|
||||
{"capabilities", {"completion"}}
|
||||
{"capabilities", has_mtmd ? json({"completion","multimodal"}) : json({"completion"})}
|
||||
};
|
||||
|
||||
res_ok(res, data);
|
||||
|
|
@ -4356,56 +4357,15 @@ int main(int argc, char ** argv) {
|
|||
// TODO: this log can become very long, put it behind a flag or think about a more compact format
|
||||
//SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str());
|
||||
|
||||
// process files
|
||||
mtmd::bitmaps bitmaps;
|
||||
const bool has_mtmd = ctx_server.mctx != nullptr;
|
||||
{
|
||||
if (!has_mtmd && !files.empty()) {
|
||||
throw std::runtime_error("This server does not support multimodal");
|
||||
}
|
||||
for (auto & file : files) {
|
||||
mtmd::bitmap bmp(mtmd_helper_bitmap_init_from_buf(ctx_server.mctx, file.data(), file.size()));
|
||||
if (!bmp.ptr) {
|
||||
throw std::runtime_error("Failed to load image or audio file");
|
||||
}
|
||||
// calculate bitmap hash (for KV caching)
|
||||
std::string hash = fnv_hash(bmp.data(), bmp.n_bytes());
|
||||
bmp.set_id(hash.c_str());
|
||||
bitmaps.entries.push_back(std::move(bmp));
|
||||
}
|
||||
}
|
||||
|
||||
// process prompt
|
||||
std::vector<server_tokens> inputs;
|
||||
|
||||
if (oaicompat && has_mtmd) {
|
||||
// multimodal
|
||||
std::string prompt_str = prompt.get<std::string>();
|
||||
mtmd_input_text inp_txt = {
|
||||
prompt_str.c_str(),
|
||||
/* add_special */ true,
|
||||
/* parse_special */ true,
|
||||
};
|
||||
mtmd::input_chunks chunks(mtmd_input_chunks_init());
|
||||
auto bitmaps_c_ptr = bitmaps.c_ptr();
|
||||
int32_t tokenized = mtmd_tokenize(ctx_server.mctx,
|
||||
chunks.ptr.get(),
|
||||
&inp_txt,
|
||||
bitmaps_c_ptr.data(),
|
||||
bitmaps_c_ptr.size());
|
||||
if (tokenized != 0) {
|
||||
throw std::runtime_error("Failed to tokenize prompt");
|
||||
}
|
||||
|
||||
server_tokens tmp(chunks, true);
|
||||
inputs.push_back(std::move(tmp));
|
||||
if (oaicompat && ctx_server.mctx != nullptr) {
|
||||
// This is the case used by OAI compatible chat path with MTMD. TODO It can be moved to the path below.
|
||||
inputs.push_back(process_mtmd_prompt(ctx_server.mctx, prompt.get<std::string>(), files));
|
||||
} else {
|
||||
// non-multimodal version
|
||||
auto tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true);
|
||||
for (auto & p : tokenized_prompts) {
|
||||
auto tmp = server_tokens(p, ctx_server.mctx != nullptr);
|
||||
inputs.push_back(std::move(tmp));
|
||||
}
|
||||
// Everything else, including multimodal completions.
|
||||
inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
|
||||
}
|
||||
|
||||
tasks.reserve(inputs.size());
|
||||
|
|
@ -4574,7 +4534,7 @@ int main(int argc, char ** argv) {
|
|||
data["input_extra"] = input_extra; // default to empty array if it's not exist
|
||||
|
||||
std::string prompt = json_value(data, "prompt", std::string());
|
||||
std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, false, true);
|
||||
std::vector<server_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, false, true);
|
||||
SRV_DBG("creating infill tasks, n_prompts = %d\n", (int) tokenized_prompts.size());
|
||||
data["prompt"] = format_infill(
|
||||
ctx_server.vocab,
|
||||
|
|
@ -4585,7 +4545,7 @@ int main(int argc, char ** argv) {
|
|||
ctx_server.params_base.n_predict,
|
||||
ctx_server.slots[0].n_ctx, // TODO: there should be a better way
|
||||
ctx_server.params_base.spm_infill,
|
||||
tokenized_prompts[0]
|
||||
tokenized_prompts[0].get_text_tokens() // TODO: this could maybe be multimodal.
|
||||
);
|
||||
|
||||
std::vector<raw_buffer> files; // dummy
|
||||
|
|
@ -4634,7 +4594,7 @@ int main(int argc, char ** argv) {
|
|||
if (current_state == SERVER_STATE_READY) {
|
||||
model_meta = ctx_server.model_meta();
|
||||
}
|
||||
|
||||
bool has_mtmd = ctx_server.mctx != nullptr;
|
||||
json models = {
|
||||
{"models", {
|
||||
{
|
||||
|
|
@ -4646,7 +4606,7 @@ int main(int argc, char ** argv) {
|
|||
{"type", "model"},
|
||||
{"description", ""},
|
||||
{"tags", {""}},
|
||||
{"capabilities", {"completion"}},
|
||||
{"capabilities", has_mtmd ? json({"completion","multimodal"}) : json({"completion"})},
|
||||
{"parameters", ""},
|
||||
{"details", {
|
||||
{"parent_model", ""},
|
||||
|
|
@ -4763,7 +4723,7 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
}
|
||||
|
||||
auto tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true);
|
||||
auto tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
|
||||
for (const auto & tokens : tokenized_prompts) {
|
||||
// this check is necessary for models that do not add BOS token to the input
|
||||
if (tokens.empty()) {
|
||||
|
|
@ -4791,7 +4751,7 @@ int main(int argc, char ** argv) {
|
|||
|
||||
task.id = ctx_server.queue_tasks.get_new_id();
|
||||
task.index = i;
|
||||
task.prompt_tokens = server_tokens(tokenized_prompts[i], ctx_server.mctx != nullptr);
|
||||
task.prompt_tokens = std::move(tokenized_prompts[i]);
|
||||
|
||||
// OAI-compat
|
||||
task.params.oaicompat = oaicompat;
|
||||
|
|
@ -4878,7 +4838,10 @@ int main(int argc, char ** argv) {
|
|||
return;
|
||||
}
|
||||
|
||||
llama_tokens tokenized_query = tokenize_input_prompts(ctx_server.vocab, query, /* add_special */ false, true)[0];
|
||||
std::vector<server_tokens> tokenized_queries = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, query, /* add_special */ false, true);
|
||||
if (tokenized_queries.size() != 1) {
|
||||
res_error(res, format_error_response("\"query\" must contain only a single prompt", ERROR_TYPE_INVALID_REQUEST));
|
||||
}
|
||||
|
||||
// create and queue the task
|
||||
json responses = json::array();
|
||||
|
|
@ -4886,14 +4849,14 @@ int main(int argc, char ** argv) {
|
|||
std::unordered_set<int> task_ids;
|
||||
{
|
||||
std::vector<server_task> tasks;
|
||||
auto tokenized_docs = tokenize_input_prompts(ctx_server.vocab, documents, /* add_special */ false, true);
|
||||
auto tokenized_docs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, documents, /* add_special */ false, true);
|
||||
tasks.reserve(tokenized_docs.size());
|
||||
for (size_t i = 0; i < tokenized_docs.size(); i++) {
|
||||
auto tmp = format_rerank(ctx_server.vocab, tokenized_query, tokenized_docs[i]);
|
||||
auto tmp = format_rerank(ctx_server.vocab, tokenized_queries[0], tokenized_docs[i]);
|
||||
server_task task = server_task(SERVER_TASK_TYPE_RERANK);
|
||||
task.id = ctx_server.queue_tasks.get_new_id();
|
||||
task.index = i;
|
||||
task.prompt_tokens = server_tokens(tmp, ctx_server.mctx != nullptr);
|
||||
task.prompt_tokens = std::move(tmp);
|
||||
tasks.push_back(std::move(task));
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -6,6 +6,8 @@ from utils import *
|
|||
|
||||
server = ServerPreset.tinyllama2()
|
||||
|
||||
JSON_MULTIMODAL_KEY = "multimodal_data"
|
||||
JSON_PROMPT_STRING_KEY = "prompt_string"
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def create_server():
|
||||
|
|
@ -231,6 +233,28 @@ def test_nocache_long_input_prompt():
|
|||
})
|
||||
assert res.status_code == 400
|
||||
|
||||
def test_json_prompt_no_mtmd():
|
||||
global server
|
||||
server.start()
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"prompt": { JSON_PROMPT_STRING_KEY: "I believe the meaning of life is" },
|
||||
"seed": 42,
|
||||
"temperature": 1.0,
|
||||
"cache_prompt": False,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
|
||||
def test_json_prompt_mtm_error_when_not_supported():
|
||||
global server
|
||||
server.start()
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"prompt": { JSON_PROMPT_STRING_KEY: "I believe the meaning of life is <__media__>", JSON_MULTIMODAL_KEY: "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=" },
|
||||
"seed": 42,
|
||||
"temperature": 1.0,
|
||||
"cache_prompt": False,
|
||||
})
|
||||
# MTMD is disabled on this model, so this should fail.
|
||||
assert res.status_code != 200
|
||||
|
||||
def test_completion_with_tokens_input():
|
||||
global server
|
||||
|
|
@ -269,6 +293,20 @@ def test_completion_with_tokens_input():
|
|||
assert len(res.body) == 2
|
||||
assert res.body[0]["content"] == res.body[1]["content"]
|
||||
|
||||
# mixed JSON and tokens
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"prompt": [
|
||||
tokens,
|
||||
{
|
||||
JSON_PROMPT_STRING_KEY: "I believe the meaning of life is",
|
||||
},
|
||||
],
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert type(res.body) == list
|
||||
assert len(res.body) == 2
|
||||
assert res.body[0]["content"] == res.body[1]["content"]
|
||||
|
||||
# mixed string and tokens in one sequence
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"prompt": [1, 2, 3, 4, 5, 6, prompt_str, 7, 8, 9, 10, prompt_str],
|
||||
|
|
|
|||
|
|
@ -10,21 +10,48 @@ IMG_URL_1 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/9
|
|||
|
||||
response = requests.get(IMG_URL_0)
|
||||
response.raise_for_status() # Raise an exception for bad status codes
|
||||
IMG_BASE64_0 = "data:image/png;base64," + base64.b64encode(response.content).decode("utf-8")
|
||||
IMG_BASE64_URI_0 = "data:image/png;base64," + base64.b64encode(response.content).decode("utf-8")
|
||||
IMG_BASE64_0 = base64.b64encode(response.content).decode("utf-8")
|
||||
|
||||
response = requests.get(IMG_URL_1)
|
||||
response.raise_for_status() # Raise an exception for bad status codes
|
||||
IMG_BASE64_URI_1 = "data:image/png;base64," + base64.b64encode(response.content).decode("utf-8")
|
||||
IMG_BASE64_1 = base64.b64encode(response.content).decode("utf-8")
|
||||
|
||||
JSON_MULTIMODAL_KEY = "multimodal_data"
|
||||
JSON_PROMPT_STRING_KEY = "prompt_string"
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def create_server():
|
||||
global server
|
||||
server = ServerPreset.tinygemma3()
|
||||
|
||||
def test_models_supports_multimodal_capability():
|
||||
global server
|
||||
server.start() # vision model may take longer to load due to download size
|
||||
res = server.make_request("GET", "/models", data={})
|
||||
assert res.status_code == 200
|
||||
model_info = res.body["models"][0]
|
||||
print(model_info)
|
||||
assert "completion" in model_info["capabilities"]
|
||||
assert "multimodal" in model_info["capabilities"]
|
||||
|
||||
def test_v1_models_supports_multimodal_capability():
|
||||
global server
|
||||
server.start() # vision model may take longer to load due to download size
|
||||
res = server.make_request("GET", "/v1/models", data={})
|
||||
assert res.status_code == 200
|
||||
model_info = res.body["models"][0]
|
||||
print(model_info)
|
||||
assert "completion" in model_info["capabilities"]
|
||||
assert "multimodal" in model_info["capabilities"]
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"prompt, image_url, success, re_content",
|
||||
[
|
||||
# test model is trained on CIFAR-10, but it's quite dumb due to small size
|
||||
("What is this:\n", IMG_URL_0, True, "(cat)+"),
|
||||
("What is this:\n", "IMG_BASE64_0", True, "(cat)+"), # exceptional, so that we don't cog up the log
|
||||
("What is this:\n", "IMG_BASE64_URI_0", True, "(cat)+"), # exceptional, so that we don't cog up the log
|
||||
("What is this:\n", IMG_URL_1, True, "(frog)+"),
|
||||
("Test test\n", IMG_URL_1, True, "(frog)+"), # test invalidate cache
|
||||
("What is this:\n", "malformed", False, None),
|
||||
|
|
@ -36,8 +63,8 @@ def create_server():
|
|||
def test_vision_chat_completion(prompt, image_url, success, re_content):
|
||||
global server
|
||||
server.start(timeout_seconds=60) # vision model may take longer to load due to download size
|
||||
if image_url == "IMG_BASE64_0":
|
||||
image_url = IMG_BASE64_0
|
||||
if image_url == "IMG_BASE64_URI_0":
|
||||
image_url = IMG_BASE64_URI_0
|
||||
res = server.make_request("POST", "/chat/completions", data={
|
||||
"temperature": 0.0,
|
||||
"top_k": 1,
|
||||
|
|
@ -58,3 +85,61 @@ def test_vision_chat_completion(prompt, image_url, success, re_content):
|
|||
else:
|
||||
assert res.status_code != 200
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"prompt, image_data, success, re_content",
|
||||
[
|
||||
# test model is trained on CIFAR-10, but it's quite dumb due to small size
|
||||
("What is this: <__media__>\n", IMG_BASE64_0, True, "(cat)+"),
|
||||
("What is this: <__media__>\n", IMG_BASE64_1, True, "(frog)+"),
|
||||
("What is this: <__media__>\n", "malformed", False, None), # non-image data
|
||||
("What is this:\n", "", False, None), # empty string
|
||||
]
|
||||
)
|
||||
def test_vision_completion(prompt, image_data, success, re_content):
|
||||
global server
|
||||
server.start() # vision model may take longer to load due to download size
|
||||
res = server.make_request("POST", "/completions", data={
|
||||
"temperature": 0.0,
|
||||
"top_k": 1,
|
||||
"prompt": { JSON_PROMPT_STRING_KEY: prompt, JSON_MULTIMODAL_KEY: [ image_data ] },
|
||||
})
|
||||
if success:
|
||||
assert res.status_code == 200
|
||||
content = res.body["content"]
|
||||
assert match_regex(re_content, content)
|
||||
else:
|
||||
assert res.status_code != 200
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"prompt, image_data, success",
|
||||
[
|
||||
# test model is trained on CIFAR-10, but it's quite dumb due to small size
|
||||
("What is this: <__media__>\n", IMG_BASE64_0, True), # exceptional, so that we don't cog up the log
|
||||
("What is this: <__media__>\n", IMG_BASE64_1, True),
|
||||
("What is this: <__media__>\n", "malformed", False), # non-image data
|
||||
("What is this:\n", "base64", False), # non-image data
|
||||
]
|
||||
)
|
||||
def test_vision_embeddings(prompt, image_data, success):
|
||||
global server
|
||||
server.server_embeddings=True
|
||||
server.n_batch=512
|
||||
server.start() # vision model may take longer to load due to download size
|
||||
res = server.make_request("POST", "/embeddings", data={
|
||||
"content": [
|
||||
{ JSON_PROMPT_STRING_KEY: prompt, JSON_MULTIMODAL_KEY: [ image_data ] },
|
||||
{ JSON_PROMPT_STRING_KEY: prompt, JSON_MULTIMODAL_KEY: [ image_data ] },
|
||||
{ JSON_PROMPT_STRING_KEY: prompt, },
|
||||
],
|
||||
})
|
||||
if success:
|
||||
assert res.status_code == 200
|
||||
content = res.body
|
||||
# Ensure embeddings are stable when multimodal.
|
||||
assert content[0]['embedding'] == content[1]['embedding']
|
||||
# Ensure embeddings without multimodal but same prompt do not match multimodal embeddings.
|
||||
assert content[0]['embedding'] != content[2]['embedding']
|
||||
else:
|
||||
assert res.status_code != 200
|
||||
|
|
|
|||
|
|
@ -123,6 +123,19 @@ static bool json_is_array_of_mixed_numbers_strings(const json & data) {
|
|||
return false;
|
||||
}
|
||||
|
||||
// does array have any individual integers/tokens?
|
||||
static bool json_is_array_and_contains_numbers(const json & data) {
|
||||
if (data.is_array()) {
|
||||
for (const auto & e : data) {
|
||||
if (e.is_number_integer()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// get value by path(key1 / key2)
|
||||
static json json_get_nested_values(const std::vector<std::string> & paths, const json & js) {
|
||||
json result = json::object();
|
||||
|
|
@ -186,48 +199,6 @@ static llama_tokens tokenize_mixed(const llama_vocab * vocab, const json & json_
|
|||
return prompt_tokens;
|
||||
}
|
||||
|
||||
/**
|
||||
* break the input "prompt" object into multiple prompt if needed, then tokenize them
|
||||
* this supports these cases:
|
||||
* - "prompt": "string"
|
||||
* - "prompt": [12, 34, 56]
|
||||
* - "prompt": [12, 34, "string", 56, 78]
|
||||
* and multiple prompts (multi-tasks):
|
||||
* - "prompt": ["string1", "string2"]
|
||||
* - "prompt": ["string1", [12, 34, 56]]
|
||||
* - "prompt": [[12, 34, 56], [78, 90, 12]]
|
||||
* - "prompt": [[12, 34, "string", 56, 78], [12, 34, 56]]
|
||||
*/
|
||||
static std::vector<llama_tokens> tokenize_input_prompts(const llama_vocab * vocab, const json & json_prompt, bool add_special, bool parse_special) {
|
||||
std::vector<llama_tokens> result;
|
||||
if (json_prompt.is_string() || json_is_array_of_mixed_numbers_strings(json_prompt)) {
|
||||
// string or mixed
|
||||
result.push_back(tokenize_mixed(vocab, json_prompt, add_special, parse_special));
|
||||
} else if (json_is_array_of_numbers(json_prompt)) {
|
||||
// array of tokens
|
||||
result.push_back(json_prompt.get<llama_tokens>());
|
||||
} else if (json_prompt.is_array()) {
|
||||
// array of prompts
|
||||
result.reserve(json_prompt.size());
|
||||
for (const auto & p : json_prompt) {
|
||||
if (p.is_string() || json_is_array_of_mixed_numbers_strings(p)) {
|
||||
result.push_back(tokenize_mixed(vocab, p, add_special, parse_special));
|
||||
} else if (json_is_array_of_numbers(p)) {
|
||||
// array of tokens
|
||||
result.push_back(p.get<llama_tokens>());
|
||||
} else {
|
||||
throw std::runtime_error("element of \"prompt\" must be a string, an list of tokens, or a list of mixed strings & tokens");
|
||||
}
|
||||
}
|
||||
} else {
|
||||
throw std::runtime_error("\"prompt\" must be a string, an list of tokens, a list of mixed strings & tokens, or a list of prompts");
|
||||
}
|
||||
if (result.empty()) {
|
||||
throw std::runtime_error("\"prompt\" must not be empty");
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// return the last index of character that can form a valid string
|
||||
// if the last character is potentially cut in half, return the index before the cut
|
||||
// if validate_utf8(text) == text.size(), then the whole text is valid utf8
|
||||
|
|
@ -262,35 +233,6 @@ static size_t validate_utf8(const std::string& text) {
|
|||
// template utils
|
||||
//
|
||||
|
||||
// format rerank task: [BOS]query[EOS][SEP]doc[EOS]
|
||||
static llama_tokens format_rerank(const struct llama_vocab * vocab, const llama_tokens & query, const llama_tokens & doc) {
|
||||
llama_tokens result;
|
||||
|
||||
// Get EOS token - use SEP token as fallback if EOS is not available
|
||||
llama_token eos_token = llama_vocab_eos(vocab);
|
||||
if (eos_token == LLAMA_TOKEN_NULL) {
|
||||
eos_token = llama_vocab_sep(vocab);
|
||||
}
|
||||
|
||||
result.reserve(doc.size() + query.size() + 4);
|
||||
if (llama_vocab_get_add_bos(vocab)) {
|
||||
result.push_back(llama_vocab_bos(vocab));
|
||||
}
|
||||
result.insert(result.end(), query.begin(), query.end());
|
||||
if (llama_vocab_get_add_eos(vocab)) {
|
||||
result.push_back(eos_token);
|
||||
}
|
||||
if (llama_vocab_get_add_sep(vocab)) {
|
||||
result.push_back(llama_vocab_sep(vocab));
|
||||
}
|
||||
result.insert(result.end(), doc.begin(), doc.end());
|
||||
if (llama_vocab_get_add_eos(vocab)) {
|
||||
result.push_back(eos_token);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// format infill task
|
||||
static llama_tokens format_infill(
|
||||
const llama_vocab * vocab,
|
||||
|
|
@ -1186,6 +1128,24 @@ public:
|
|||
}
|
||||
}
|
||||
|
||||
// appends server tokens, updates the media map. copies media chunks.
|
||||
void push_back(server_tokens & tokens) {
|
||||
size_t start_pos = size();
|
||||
for (size_t i = 0; i < tokens.size(); i++) {
|
||||
push_back(tokens[i]);
|
||||
}
|
||||
if (tokens.has_mtmd) {
|
||||
// Assert if we are copying MTMD chunks to a server_tokens that does not have mtmd.
|
||||
// We could also just check, but this will prevent silently dropping MTMD data.
|
||||
GGML_ASSERT(has_mtmd);
|
||||
for (auto it = tokens.map_pos_to_media.begin(); it != tokens.map_pos_to_media.end(); ) {
|
||||
auto chunk = tokens.map_pos_to_media[it->first].get();
|
||||
mtmd::input_chunk_ptr new_chunk(mtmd_input_chunk_copy(chunk));
|
||||
map_pos_to_media[start_pos+it->first] = std::move(new_chunk);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// for compatibility with context shift and prompt truncation
|
||||
void insert(const llama_tokens & inp_tokens) {
|
||||
GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled
|
||||
|
|
@ -1356,3 +1316,137 @@ static std::string fnv_hash(const uint8_t * data, size_t len) {
|
|||
}
|
||||
return std::to_string(hash);
|
||||
}
|
||||
|
||||
|
||||
// format rerank task: [BOS]query[EOS][SEP]doc[EOS].
|
||||
static server_tokens format_rerank(const struct llama_vocab * vocab, server_tokens & query, server_tokens & doc) {
|
||||
server_tokens result = {};
|
||||
|
||||
// Get EOS token - use SEP token as fallback if EOS is not available
|
||||
llama_token eos_token = llama_vocab_eos(vocab);
|
||||
if (eos_token == LLAMA_TOKEN_NULL) {
|
||||
eos_token = llama_vocab_sep(vocab);
|
||||
}
|
||||
if (llama_vocab_get_add_bos(vocab)) {
|
||||
result.push_back(llama_vocab_bos(vocab));
|
||||
}
|
||||
result.push_back(query);
|
||||
if (llama_vocab_get_add_eos(vocab)) {
|
||||
result.push_back(eos_token);
|
||||
}
|
||||
if (llama_vocab_get_add_sep(vocab)) {
|
||||
result.push_back(llama_vocab_sep(vocab));
|
||||
}
|
||||
result.push_back(doc);
|
||||
if (llama_vocab_get_add_eos(vocab)) {
|
||||
result.push_back(eos_token);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
static server_tokens process_mtmd_prompt(mtmd_context * mctx, std::string prompt, std::vector<raw_buffer> files) {
|
||||
mtmd::bitmaps bitmaps;
|
||||
for (auto & file : files) {
|
||||
mtmd::bitmap bmp(mtmd_helper_bitmap_init_from_buf(mctx, file.data(), file.size()));
|
||||
if (!bmp.ptr) {
|
||||
throw std::runtime_error("Failed to load image or audio file");
|
||||
}
|
||||
// calculate bitmap hash (for KV caching)
|
||||
std::string hash = fnv_hash(bmp.data(), bmp.n_bytes());
|
||||
bmp.set_id(hash.c_str());
|
||||
bitmaps.entries.push_back(std::move(bmp));
|
||||
}
|
||||
// process prompt
|
||||
std::vector<server_tokens> inputs;
|
||||
// multimodal
|
||||
mtmd_input_text inp_txt = {
|
||||
prompt.c_str(),
|
||||
/* add_special */ true,
|
||||
/* parse_special */ true,
|
||||
};
|
||||
mtmd::input_chunks chunks(mtmd_input_chunks_init());
|
||||
auto bitmaps_c_ptr = bitmaps.c_ptr();
|
||||
int32_t tokenized = mtmd_tokenize(mctx,
|
||||
chunks.ptr.get(),
|
||||
&inp_txt,
|
||||
bitmaps_c_ptr.data(),
|
||||
bitmaps_c_ptr.size());
|
||||
if (tokenized != 0) {
|
||||
throw std::runtime_error("Failed to tokenize prompt");
|
||||
}
|
||||
auto result = server_tokens(chunks, true);
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* break the input "prompt" object into multiple prompt if needed, then tokenize them
|
||||
* use tokenize_input_prompts() if the input could be an array.
|
||||
* this supports these cases:
|
||||
* - "prompt": "string"
|
||||
* - "prompt": [12, 34, 56]
|
||||
* - "prompt": [12, 34, "string", 56, 78]
|
||||
* - "prompt": { "prompt_string": "string", "multimodal_data": [ "base64" ] }
|
||||
*/
|
||||
static server_tokens tokenize_input_subprompt(const llama_vocab * vocab, mtmd_context * mctx, const json & json_prompt, bool add_special, bool parse_special) {
|
||||
constexpr char JSON_STRING_PROMPT_KEY[] = "prompt_string";
|
||||
constexpr char JSON_MTMD_DATA_KEY[] = "multimodal_data";
|
||||
const bool has_mtmd = mctx != nullptr;
|
||||
if (json_prompt.is_string() || json_is_array_of_mixed_numbers_strings(json_prompt)) {
|
||||
// string or mixed
|
||||
llama_tokens tmp = tokenize_mixed(vocab, json_prompt, add_special, parse_special);
|
||||
return server_tokens(tmp, false);
|
||||
} else if (json_is_array_of_numbers(json_prompt)) {
|
||||
// array of tokens
|
||||
llama_tokens tmp = json_prompt.get<llama_tokens>();
|
||||
return server_tokens(tmp, false);
|
||||
} else if (json_prompt.contains(JSON_STRING_PROMPT_KEY)) {
|
||||
// JSON object with prompt key.
|
||||
if (json_prompt.contains(JSON_MTMD_DATA_KEY)) {
|
||||
if (!has_mtmd)
|
||||
throw std::runtime_error("Multimodal data provided, but model does not support multimodal requests.");
|
||||
|
||||
// JSON object with prompt and multimodal key.
|
||||
std::vector<raw_buffer> files;
|
||||
for (const auto & entry : json_prompt.at(JSON_MTMD_DATA_KEY)) {
|
||||
files.push_back(base64_decode(entry));
|
||||
}
|
||||
return process_mtmd_prompt(mctx, json_prompt.at(JSON_STRING_PROMPT_KEY), files);
|
||||
} else {
|
||||
// Not multimodal, but contains a subobject.
|
||||
llama_tokens tmp = tokenize_mixed(vocab, json_prompt.at(JSON_STRING_PROMPT_KEY), add_special, parse_special);
|
||||
return server_tokens(tmp, false);
|
||||
}
|
||||
} else {
|
||||
throw std::runtime_error("\"prompt\" elements must be a string, a list of tokens, a JSON object containing a prompt string, or a list of mixed strings & tokens.");
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* break the input "prompt" object into multiple prompt if needed, then tokenize them
|
||||
* this supports these cases:
|
||||
* - "prompt": "string"
|
||||
* - "prompt": [12, 34, 56]
|
||||
* - "prompt": [12, 34, "string", 56, 78]
|
||||
* - "prompt": { "prompt_string": "string", "multimodal_data": [ "base64" ] }
|
||||
* and multiple prompts (multi-tasks):
|
||||
* - "prompt": ["string1", "string2"]
|
||||
* - "prompt": ["string1", [12, 34, 56]]
|
||||
* - "prompt": [[12, 34, 56], [78, 90, 12]]
|
||||
* - "prompt": [[12, 34, "string", 56, 78], [12, 34, 56], { "prompt_string": "string", "multimodal_data": [ "base64" ]}]
|
||||
*/
|
||||
static std::vector<server_tokens> tokenize_input_prompts(const llama_vocab * vocab, mtmd_context * mctx, const json & json_prompt, bool add_special, bool parse_special) {
|
||||
std::vector<server_tokens> result;
|
||||
if (json_prompt.is_array() && !json_is_array_and_contains_numbers(json_prompt)) {
|
||||
result.reserve(json_prompt.size());
|
||||
for (const auto & p : json_prompt) {
|
||||
result.push_back(tokenize_input_subprompt(vocab, mctx, p,add_special, parse_special));
|
||||
}
|
||||
} else {
|
||||
result.push_back(tokenize_input_subprompt(vocab, mctx, json_prompt, add_special, parse_special));
|
||||
}
|
||||
if (result.empty()) {
|
||||
throw std::runtime_error("\"prompt\" must not be empty");
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue