mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2025-09-10 17:14:36 +00:00
Merge branch 'upstream' into concedo_experimental
# Conflicts: # .github/workflows/build.yml # README.md # ci/run.sh # docs/build.md # examples/CMakeLists.txt # examples/parallel/parallel.cpp # ggml/CMakeLists.txt # ggml/src/CMakeLists.txt # scripts/server-bench.py # src/llama-kv-cache-unified.cpp # tests/test-backend-ops.cpp # tools/batched-bench/batched-bench.cpp # tools/server/README.md
This commit is contained in:
commit
bdff33e0de
47 changed files with 3128 additions and 509 deletions
|
@ -1466,6 +1466,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
params.swa_full = true;
|
||||
}
|
||||
).set_env("LLAMA_ARG_SWA_FULL"));
|
||||
add_opt(common_arg(
|
||||
{"--kv-unified", "-kvu"},
|
||||
string_format("use single unified KV buffer for the KV cache of all sequences (default: %s)\n"
|
||||
"[(more info)](https://github.com/ggml-org/llama.cpp/pull/14363)", params.kv_unified ? "true" : "false"),
|
||||
[](common_params & params) {
|
||||
params.kv_unified = true;
|
||||
}
|
||||
).set_env("LLAMA_ARG_KV_SPLIT"));
|
||||
add_opt(common_arg(
|
||||
{"--no-context-shift"},
|
||||
string_format("disables context shift on infinite text generation (default: %s)", params.ctx_shift ? "disabled" : "enabled"),
|
||||
|
@ -3425,5 +3433,34 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}));
|
||||
|
||||
// diffusion parameters
|
||||
add_opt(common_arg(
|
||||
{ "--diffusion-steps" }, "N",
|
||||
string_format("number of diffusion steps (default: %d)", params.diffusion.steps),
|
||||
[](common_params & params, int value) { params.diffusion.steps = value; }
|
||||
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
|
||||
add_opt(common_arg(
|
||||
{ "--diffusion-eps" }, "F",
|
||||
string_format("epsilon for timesteps (default: %.6f)", (double) params.diffusion.eps),
|
||||
[](common_params & params, const std::string & value) { params.diffusion.eps = std::stof(value); }
|
||||
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
|
||||
add_opt(common_arg(
|
||||
{ "--diffusion-algorithm" }, "N",
|
||||
string_format("diffusion algorithm: 0=ORIGIN, 1=MASKGIT_PLUS, 2=TOPK_MARGIN, 3=ENTROPY (default: %d)",
|
||||
params.diffusion.algorithm),
|
||||
[](common_params & params, int value) { params.diffusion.algorithm = value; }
|
||||
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
|
||||
add_opt(common_arg(
|
||||
{ "--diffusion-alg-temp" }, "F",
|
||||
string_format("algorithm temperature (default: %.3f)", (double) params.diffusion.alg_temp),
|
||||
[](common_params & params, const std::string & value) { params.diffusion.alg_temp = std::stof(value); }
|
||||
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
|
||||
add_opt(common_arg(
|
||||
{ "--diffusion-visual" },
|
||||
string_format("enable visual diffusion mode (show progressive generation) (default: %s)",
|
||||
params.diffusion.visual_mode ? "true" : "false"),
|
||||
[](common_params & params) { params.diffusion.visual_mode = true; }
|
||||
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
|
||||
|
||||
return ctx_arg;
|
||||
}
|
||||
|
|
|
@ -1013,15 +1013,21 @@ struct common_init_result common_init_from_params(common_params & params) {
|
|||
params.sampling.ignore_eos = false;
|
||||
}
|
||||
|
||||
if (params.sampling.ignore_eos) {
|
||||
for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
|
||||
if (llama_vocab_is_eog(vocab, i)) {
|
||||
LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(lctx, i).c_str(), -INFINITY);
|
||||
params.sampling.logit_bias.push_back({i, -INFINITY});
|
||||
}
|
||||
// initialize once
|
||||
for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
|
||||
if (llama_vocab_is_eog(vocab, i)) {
|
||||
LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(lctx, i).c_str(), -INFINITY);
|
||||
params.sampling.logit_bias_eog.push_back({i, -INFINITY});
|
||||
}
|
||||
}
|
||||
|
||||
if (params.sampling.ignore_eos) {
|
||||
// add EOG biases to the active set of logit biases
|
||||
params.sampling.logit_bias.insert(
|
||||
params.sampling.logit_bias.end(),
|
||||
params.sampling.logit_bias_eog.begin(), params.sampling.logit_bias_eog.end());
|
||||
}
|
||||
|
||||
if (params.sampling.penalty_last_n == -1) {
|
||||
LOG_INF("%s: setting penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
|
||||
params.sampling.penalty_last_n = llama_n_ctx(lctx);
|
||||
|
@ -1165,6 +1171,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &
|
|||
cparams.no_perf = params.no_perf;
|
||||
cparams.op_offload = !params.no_op_offload;
|
||||
cparams.swa_full = params.swa_full;
|
||||
cparams.kv_unified = params.kv_unified;
|
||||
|
||||
cparams.type_k = params.cache_type_k;
|
||||
cparams.type_v = params.cache_type_v;
|
||||
|
|
|
@ -77,6 +77,7 @@ enum llama_example {
|
|||
LLAMA_EXAMPLE_LOOKUP,
|
||||
LLAMA_EXAMPLE_PARALLEL,
|
||||
LLAMA_EXAMPLE_TTS,
|
||||
LLAMA_EXAMPLE_DIFFUSION,
|
||||
|
||||
LLAMA_EXAMPLE_COUNT,
|
||||
};
|
||||
|
@ -173,7 +174,8 @@ struct common_params_sampling {
|
|||
std::vector<common_grammar_trigger> grammar_triggers; // optional triggers (for lazy grammars)
|
||||
std::set<llama_token> preserved_tokens;
|
||||
|
||||
std::vector<llama_logit_bias> logit_bias; // logit biases to apply
|
||||
std::vector<llama_logit_bias> logit_bias; // logit biases to apply
|
||||
std::vector<llama_logit_bias> logit_bias_eog; // pre-calculated logit biases for EOG tokens
|
||||
|
||||
// print the parameters into a string
|
||||
std::string print() const;
|
||||
|
@ -213,6 +215,14 @@ struct common_params_vocoder {
|
|||
bool use_guide_tokens = false; // enable guide tokens to improve TTS accuracy // NOLINT
|
||||
};
|
||||
|
||||
struct common_params_diffusion {
|
||||
int32_t steps = 64; // number of diffusion steps
|
||||
float eps = 1e-3f; // epsilon for timesteps
|
||||
int32_t algorithm = 0; // diffusion algorithm (0=ORIGIN, 1=MASKGIT_PLUS, 2=TOPK_MARGIN, 3=ENTROPY)
|
||||
float alg_temp = 0.0f; // algorithm temperature
|
||||
bool visual_mode = false; // show progressive diffusion on screen
|
||||
};
|
||||
|
||||
enum common_reasoning_format {
|
||||
COMMON_REASONING_FORMAT_NONE,
|
||||
COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY, // Extract thinking tag contents and return as `message.reasoning_content`, or leave inline in <think> tags in stream mode
|
||||
|
@ -264,6 +274,7 @@ struct common_params {
|
|||
struct common_params_sampling sampling;
|
||||
struct common_params_speculative speculative;
|
||||
struct common_params_vocoder vocoder;
|
||||
struct common_params_diffusion diffusion;
|
||||
|
||||
struct common_params_model model;
|
||||
|
||||
|
@ -326,6 +337,7 @@ struct common_params {
|
|||
bool no_perf = false; // disable performance metrics
|
||||
bool ctx_shift = true; // context shift on inifinite text generation
|
||||
bool swa_full = false; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
|
||||
bool kv_unified = false; // enable unified KV cache
|
||||
|
||||
bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
|
||||
bool use_mmap = true; // use mmap for faster loads
|
||||
|
|
|
@ -669,6 +669,36 @@ class TextModel(ModelBase):
|
|||
# NOTE: if you get an error here, you need to update the convert_hf_to_gguf_update.py script
|
||||
# or pull the latest version of the model from Huggingface
|
||||
# don't edit the hashes manually!
|
||||
if chkhsh == "b6e8e1518dc4305be2fe39c313ed643381c4da5db34a98f6a04c093f8afbe99b":
|
||||
# ref: https://huggingface.co/THUDM/glm-4-9b-chat
|
||||
res = "chatglm-bpe"
|
||||
if chkhsh == "81d72c7348a9f0ebe86f23298d37debe0a5e71149e29bd283904c02262b27516":
|
||||
# ref: https://huggingface.co/THUDM/glm-4-9b-chat
|
||||
res = "chatglm-bpe"
|
||||
if chkhsh == "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2":
|
||||
# ref: https://huggingface.co/THUDM/glm-4-9b-hf
|
||||
res = "glm4"
|
||||
if chkhsh == "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35":
|
||||
# ref: https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0
|
||||
res = "minerva-7b"
|
||||
if chkhsh == "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664":
|
||||
# ref: https://huggingface.co/tencent/Hunyuan-A13B-Instruct
|
||||
res = "hunyuan"
|
||||
if chkhsh == "a6b57017d60e6edb4d88ecc2845188e0eb333a70357e45dcc9b53964a73bbae6":
|
||||
# ref: https://huggingface.co/tiiuae/Falcon-H1-0.5B-Base
|
||||
res = "falcon-h1"
|
||||
if chkhsh == "60476e1243776c4fb1b993dbd7a5f15ac22f83c80afdf425fa5ae01c8d44ef86":
|
||||
# ref: https://huggingface.co/tiiuae/Falcon-H1-1B-Base
|
||||
res = "falcon-h1"
|
||||
if chkhsh == "3eda48b4c4dc7de733d1a8b3e3b4a85243dbbf704da2ee9d42c6beced8897896":
|
||||
# ref: https://huggingface.co/tiiuae/Falcon-H1-7B-Base
|
||||
res = "falcon-h1"
|
||||
if chkhsh == "48f8e02c0359c0bbdd82f26909171fac1c18a457bb47573ed1fe3bbb2c1cfd4b":
|
||||
# ref: https://huggingface.co/tiiuae/Falcon-H1-34B-Base
|
||||
res = "falcon-h1"
|
||||
if chkhsh == "81212dc7cdb7e0c1074ca62c5aeab0d43c9f52b8a737be7b12a777c953027890":
|
||||
# ref: https://huggingface.co/moonshotai/Kimi-K2-Base
|
||||
res = "kimi-k2"
|
||||
if chkhsh == "0ef9807a4087ebef797fc749390439009c3b9eda9ad1a097abbe738f486c01e5":
|
||||
# ref: https://huggingface.co/meta-llama/Meta-Llama-3-8B
|
||||
res = "llama-bpe"
|
||||
|
@ -804,45 +834,15 @@ class TextModel(ModelBase):
|
|||
if chkhsh == "d5f1dd6f980fec569fb218a81a7658ac45fc56b38c5a0adeb1c232fbe04ef5ec":
|
||||
# ref: https://huggingface.co/ByteDance-Seed/Seed-Coder-8B-Base
|
||||
res = "seed-coder"
|
||||
if chkhsh == "b6e8e1518dc4305be2fe39c313ed643381c4da5db34a98f6a04c093f8afbe99b":
|
||||
# ref: https://huggingface.co/THUDM/glm-4-9b-chat
|
||||
res = "chatglm-bpe"
|
||||
if chkhsh == "81d72c7348a9f0ebe86f23298d37debe0a5e71149e29bd283904c02262b27516":
|
||||
# ref: https://huggingface.co/THUDM/glm-4-9b-chat
|
||||
res = "chatglm-bpe"
|
||||
if chkhsh == "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2":
|
||||
# ref: https://huggingface.co/THUDM/glm-4-9b-hf
|
||||
res = "glm4"
|
||||
if chkhsh == "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35":
|
||||
# ref: https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0
|
||||
res = "minerva-7b"
|
||||
if chkhsh == "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664":
|
||||
# ref: https://huggingface.co/tencent/Hunyuan-A13B-Instruct
|
||||
res = "hunyuan"
|
||||
if chkhsh == "b0a6b1c0bd5998ebd9df08611efde34a4ff03faed45ae09c43e6b31ebd4b94cf":
|
||||
# ref: https://huggingface.co/skt/A.X-4.0
|
||||
res = "a.x-4.0"
|
||||
if chkhsh == "a6b57017d60e6edb4d88ecc2845188e0eb333a70357e45dcc9b53964a73bbae6":
|
||||
# ref: https://huggingface.co/tiiuae/Falcon-H1-0.5B-Base
|
||||
res = "falcon-h1"
|
||||
if chkhsh == "60476e1243776c4fb1b993dbd7a5f15ac22f83c80afdf425fa5ae01c8d44ef86":
|
||||
# ref: https://huggingface.co/tiiuae/Falcon-H1-1B-Base
|
||||
res = "falcon-h1"
|
||||
if chkhsh == "3eda48b4c4dc7de733d1a8b3e3b4a85243dbbf704da2ee9d42c6beced8897896":
|
||||
# ref: https://huggingface.co/tiiuae/Falcon-H1-7B-Base
|
||||
res = "falcon-h1"
|
||||
if chkhsh == "48f8e02c0359c0bbdd82f26909171fac1c18a457bb47573ed1fe3bbb2c1cfd4b":
|
||||
# ref: https://huggingface.co/tiiuae/Falcon-H1-34B-Base
|
||||
res = "falcon-h1"
|
||||
if chkhsh == "f6791d196f87ce6b56a7d234be618e0d58f8cda3549416635b2bebcd22cd95c4":
|
||||
# ref: https://huggingface.co/K-intelligence/Midm-2.0-Base-Instruct
|
||||
res = "midm-2.0"
|
||||
if chkhsh == "169bf0296a13c4d9b7672313f749eb36501d931022de052aad6e36f2bf34dd51":
|
||||
# ref: https://huggingface.co/LiquidAI/LFM2-Tokenizer
|
||||
res = "lfm2"
|
||||
if chkhsh == "81212dc7cdb7e0c1074ca62c5aeab0d43c9f52b8a737be7b12a777c953027890":
|
||||
# ref: https://huggingface.co/moonshotai/Kimi-K2-Base
|
||||
res = "kimi-k2"
|
||||
|
||||
if res is None:
|
||||
logger.warning("\n")
|
||||
|
@ -2778,6 +2778,76 @@ class Qwen2Model(TextModel):
|
|||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
@ModelBase.register("DreamModel")
|
||||
class DreamModel(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.DREAM
|
||||
|
||||
def get_vocab_base(self) -> tuple[list[str], list[int], str]:
|
||||
tokens: list[str] = []
|
||||
toktypes: list[int] = []
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True)
|
||||
|
||||
vocab_dict = tokenizer.get_vocab()
|
||||
vocab_size = self.hparams.get("vocab_size", len(vocab_dict))
|
||||
assert max(vocab_dict.values()) < vocab_size
|
||||
|
||||
tokpre = self.get_vocab_base_pre(tokenizer)
|
||||
|
||||
reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in vocab_dict.items()}
|
||||
added_vocab = tokenizer.get_added_vocab()
|
||||
|
||||
for i in range(vocab_size):
|
||||
if i not in reverse_vocab:
|
||||
tokens.append(f"[PAD{i}]")
|
||||
toktypes.append(gguf.TokenType.UNUSED)
|
||||
elif reverse_vocab[i] in added_vocab:
|
||||
tokens.append(reverse_vocab[i])
|
||||
# Check if it's a special token - treat special tokens as CONTROL tokens
|
||||
if hasattr(tokenizer, 'added_tokens_decoder') and i in tokenizer.added_tokens_decoder:
|
||||
if tokenizer.added_tokens_decoder[i].special:
|
||||
toktypes.append(gguf.TokenType.CONTROL)
|
||||
else:
|
||||
toktypes.append(gguf.TokenType.USER_DEFINED)
|
||||
else:
|
||||
# Fallback: treat all added vocab as control tokens for special tokens like <|im_start|>
|
||||
toktypes.append(gguf.TokenType.CONTROL)
|
||||
else:
|
||||
tokens.append(reverse_vocab[i])
|
||||
toktypes.append(gguf.TokenType.NORMAL)
|
||||
|
||||
return tokens, toktypes, tokpre
|
||||
|
||||
def set_vocab(self):
|
||||
try:
|
||||
self._set_vocab_sentencepiece()
|
||||
except FileNotFoundError:
|
||||
self._set_vocab_gpt2()
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
self._try_set_pooling_type()
|
||||
|
||||
# Dream models use non-causal attention for diffusion
|
||||
self.gguf_writer.add_causal_attention(False)
|
||||
# Handle RoPE scaling similar to Qwen2
|
||||
rope_scaling = self.hparams.get("rope_scaling") or {}
|
||||
if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling:
|
||||
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
|
||||
self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
|
||||
self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])
|
||||
|
||||
# Add Dream-specific parameters
|
||||
mask_token_id = self.hparams.get("mask_token_id")
|
||||
if mask_token_id is not None:
|
||||
self.gguf_writer.add_mask_token_id(mask_token_id)
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
# Dream model tensors should be mapped directly since it's the base model
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
@ModelBase.register("Ernie4_5_ForCausalLM")
|
||||
class Ernie4_5Model(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.ERNIE4_5
|
||||
|
|
|
@ -232,7 +232,7 @@ for model in models:
|
|||
# generate the source code for the convert_hf_to_gguf.py:get_vocab_base_pre() function:
|
||||
|
||||
src_ifs = ""
|
||||
for model in [*all_models, *pre_computed_hashes]:
|
||||
for model in [*pre_computed_hashes, *all_models]:
|
||||
name = model["name"]
|
||||
tokt = model["tokt"]
|
||||
chkhsh = model.get("chkhsh")
|
||||
|
@ -240,11 +240,6 @@ for model in [*all_models, *pre_computed_hashes]:
|
|||
if tokt == TOKENIZER_TYPE.SPM or tokt == TOKENIZER_TYPE.UGM:
|
||||
continue
|
||||
|
||||
# Skip if the tokenizer folder does not exist or there are other download issues previously
|
||||
if not os.path.exists(f"models/tokenizers/{name}"):
|
||||
logger.warning(f"Directory for tokenizer {name} not found. Skipping...")
|
||||
continue
|
||||
|
||||
# create the tokenizer
|
||||
if chkhsh is not None:
|
||||
# if the model has a pre-computed hash, use it
|
||||
|
@ -254,6 +249,12 @@ for model in [*all_models, *pre_computed_hashes]:
|
|||
chkhsh = existing_models[name]
|
||||
else:
|
||||
# otherwise, compute the hash of the tokenizer
|
||||
|
||||
# Skip if the tokenizer folder does not exist or there are other download issues previously
|
||||
if not os.path.exists(f"models/tokenizers/{name}"):
|
||||
logger.warning(f"Directory for tokenizer {name} not found. Skipping...")
|
||||
continue
|
||||
|
||||
try:
|
||||
logger.info(f"Loading tokenizer from {f'models/tokenizers/{name}'}...")
|
||||
if name == "t5":
|
||||
|
|
5
examples/diffusion/CMakeLists.txt
Normal file
5
examples/diffusion/CMakeLists.txt
Normal file
|
@ -0,0 +1,5 @@
|
|||
set(TARGET llama-diffusion-cli)
|
||||
add_executable(${TARGET} diffusion-cli.cpp)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
target_link_libraries(${TARGET} PRIVATE llama common ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
507
examples/diffusion/diffusion-cli.cpp
Normal file
507
examples/diffusion/diffusion-cli.cpp
Normal file
|
@ -0,0 +1,507 @@
|
|||
#include "arg.h"
|
||||
#include "chat.h"
|
||||
#include "common.h"
|
||||
#include "llama.h"
|
||||
#include "log.h"
|
||||
|
||||
#include <limits.h>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <limits>
|
||||
#include <random>
|
||||
|
||||
typedef bool (*diffusion_step_callback_t)(int32_t step,
|
||||
int32_t total_steps,
|
||||
const llama_token * tokens,
|
||||
int32_t n_tokens,
|
||||
void * user_data);
|
||||
|
||||
enum diffusion_alg {
|
||||
DIFFUSION_ALG_ORIGIN = 0,
|
||||
DIFFUSION_ALG_MASKGIT_PLUS = 1,
|
||||
DIFFUSION_ALG_TOPK_MARGIN = 2,
|
||||
DIFFUSION_ALG_ENTROPY = 3,
|
||||
};
|
||||
|
||||
struct diffusion_params {
|
||||
int32_t steps;
|
||||
float eps;
|
||||
float temperature;
|
||||
float top_p;
|
||||
int32_t top_k;
|
||||
llama_token mask_token_id;
|
||||
enum diffusion_alg algorithm;
|
||||
float alg_temp;
|
||||
diffusion_step_callback_t step_callback;
|
||||
void * step_callback_user_data;
|
||||
int32_t seed;
|
||||
};
|
||||
|
||||
|
||||
static diffusion_params diffusion_default_params() {
|
||||
diffusion_params params = {};
|
||||
params.steps = 64;
|
||||
params.eps = 1e-3f;
|
||||
params.temperature = 0.2f;
|
||||
params.top_p = 0.95f;
|
||||
params.top_k = 0;
|
||||
params.mask_token_id = LLAMA_TOKEN_NULL;
|
||||
params.algorithm = DIFFUSION_ALG_ORIGIN;
|
||||
params.alg_temp = 0.0f;
|
||||
params.step_callback = nullptr;
|
||||
params.step_callback_user_data = nullptr;
|
||||
params.seed = 0;
|
||||
return params;
|
||||
}
|
||||
|
||||
static void diffusion_generate(llama_context * ctx,
|
||||
const llama_token * input_tokens,
|
||||
llama_token * output_tokens,
|
||||
int32_t n_input,
|
||||
int32_t max_length,
|
||||
struct diffusion_params params,
|
||||
int32_t & n_generated) {
|
||||
|
||||
n_generated = 0;
|
||||
if (!ctx || !input_tokens || !output_tokens || n_input <= 0 || max_length <= n_input) {
|
||||
return;
|
||||
}
|
||||
|
||||
const llama_model * model = llama_get_model(ctx);
|
||||
|
||||
// Initialize with input and pad with mask tokens
|
||||
std::copy(input_tokens, input_tokens + n_input, output_tokens);
|
||||
std::fill(output_tokens + n_input, output_tokens + max_length, params.mask_token_id);
|
||||
|
||||
std::mt19937 rng(params.seed);
|
||||
|
||||
std::vector<float> timesteps(params.steps + 1);
|
||||
for (int32_t i = 0; i <= params.steps; i++) {
|
||||
timesteps[i] = 1.0f - (float) i / params.steps * (1.0f - params.eps);
|
||||
}
|
||||
|
||||
llama_set_causal_attn(ctx, false);
|
||||
|
||||
int32_t n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(model));
|
||||
|
||||
std::vector<llama_token_data> candidates(n_vocab);
|
||||
|
||||
std::vector<llama_token_data> conf_candidates;
|
||||
conf_candidates.reserve(max_length);
|
||||
|
||||
std::vector<int32_t> mask_positions;
|
||||
mask_positions.reserve(max_length);
|
||||
|
||||
struct llama_sampler * sampler = llama_sampler_chain_init(llama_sampler_chain_default_params());
|
||||
if (params.top_k > 0) {
|
||||
llama_sampler_chain_add(sampler, llama_sampler_init_top_k(params.top_k));
|
||||
}
|
||||
if (params.top_p < 1.0f) {
|
||||
llama_sampler_chain_add(sampler, llama_sampler_init_top_p(params.top_p, 1));
|
||||
}
|
||||
if (params.temperature > 0.0f) {
|
||||
llama_sampler_chain_add(sampler, llama_sampler_init_temp(params.temperature));
|
||||
}
|
||||
llama_sampler_chain_add(sampler, llama_sampler_init_dist(params.seed));
|
||||
|
||||
struct llama_sampler * dist_sampler = llama_sampler_init_dist(params.seed);
|
||||
|
||||
llama_batch batch = llama_batch_init(max_length, 0, 1);
|
||||
batch.n_tokens = max_length;
|
||||
|
||||
int64_t total_sampling_time = 0;
|
||||
int64_t total_time = 0;
|
||||
|
||||
int64_t time_start = ggml_time_us();
|
||||
for (int32_t step = 0; step < params.steps; step++) {
|
||||
if (params.step_callback) {
|
||||
if (!params.step_callback(step, params.steps, output_tokens, max_length, params.step_callback_user_data)) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
for (int32_t i = 0; i < max_length; i++) {
|
||||
batch.token[i] = output_tokens[i];
|
||||
batch.pos[i] = i;
|
||||
batch.n_seq_id[i] = 1;
|
||||
batch.seq_id[i][0] = 0;
|
||||
batch.logits[i] = 1;
|
||||
}
|
||||
|
||||
int ret = llama_decode(ctx, batch);
|
||||
if (ret != 0) {
|
||||
LOG_ERR("%s: failed to decode at step %d, ret = %d\n", __func__, step, ret);
|
||||
break;
|
||||
}
|
||||
|
||||
float * raw_logits = llama_get_logits(ctx);
|
||||
if (!raw_logits) {
|
||||
LOG_ERR("%s: failed to get logits at step %d\n", __func__, step);
|
||||
break;
|
||||
}
|
||||
|
||||
auto get_logits_for_pos = [&](int32_t pos) -> const float * {
|
||||
return pos == 0 ? raw_logits : raw_logits + (pos - 1) * n_vocab;
|
||||
};
|
||||
|
||||
int64_t time_start_sampling = ggml_time_us();
|
||||
|
||||
mask_positions.clear();
|
||||
for (int32_t i = 0; i < max_length; i++) {
|
||||
if (output_tokens[i] == params.mask_token_id) {
|
||||
mask_positions.push_back(i);
|
||||
}
|
||||
}
|
||||
|
||||
if (mask_positions.empty()) {
|
||||
break;
|
||||
}
|
||||
|
||||
float t = timesteps[step];
|
||||
float s = timesteps[step + 1];
|
||||
|
||||
if (params.algorithm == DIFFUSION_ALG_ORIGIN) {
|
||||
float p_transfer = (step < params.steps - 1) ? (1.0f - s / t) : 1.0f;
|
||||
|
||||
for (int32_t pos : mask_positions) {
|
||||
if (std::uniform_real_distribution<float>(0.0f, 1.0f)(rng) < p_transfer) {
|
||||
const float * pos_logits = get_logits_for_pos(pos);
|
||||
for (int32_t token_id = 0; token_id < n_vocab; token_id++) {
|
||||
candidates[token_id].id = token_id;
|
||||
candidates[token_id].logit = pos_logits[token_id];
|
||||
candidates[token_id].p = 0.0f;
|
||||
}
|
||||
|
||||
llama_token_data_array cur_p = {
|
||||
/* .data = */ candidates.data(),
|
||||
/* .size = */ (size_t) n_vocab, // Reset size to full vocab
|
||||
/* .selected = */ -1,
|
||||
/* .sorted = */ false,
|
||||
};
|
||||
|
||||
llama_sampler_apply(sampler, &cur_p);
|
||||
output_tokens[pos] = cur_p.data[cur_p.selected].id;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
std::vector<std::pair<float, int32_t>> confidences;
|
||||
std::vector<llama_token> sampled_tokens(mask_positions.size());
|
||||
|
||||
for (size_t i = 0; i < mask_positions.size(); i++) {
|
||||
int32_t pos = mask_positions[i];
|
||||
const float * pos_logits = get_logits_for_pos(pos);
|
||||
|
||||
for (int32_t token_id = 0; token_id < n_vocab; token_id++) {
|
||||
candidates[token_id].logit = pos_logits[token_id];
|
||||
candidates[token_id].p = 0.0f;
|
||||
candidates[token_id].id = token_id;
|
||||
}
|
||||
|
||||
llama_token_data_array cur_p = {
|
||||
/* .data = */ candidates.data(),
|
||||
/* .size = */ candidates.size(),
|
||||
/* .selected = */ -1,
|
||||
/* .sorted = */ false,
|
||||
};
|
||||
|
||||
llama_sampler_apply(sampler, &cur_p);
|
||||
|
||||
llama_token sampled_token = cur_p.data[cur_p.selected].id;
|
||||
|
||||
float confidence = 0.0f;
|
||||
if (params.algorithm == DIFFUSION_ALG_ENTROPY) {
|
||||
const float epsilon = 1e-10f;
|
||||
for (size_t j = 0; j < cur_p.size; j++) {
|
||||
float prob = cur_p.data[j].p;
|
||||
confidence += prob * logf(prob + epsilon);
|
||||
}
|
||||
} else if (params.algorithm == DIFFUSION_ALG_TOPK_MARGIN) {
|
||||
confidence = cur_p.data[0].p - cur_p.data[1].p;
|
||||
} else {
|
||||
confidence = cur_p.data[cur_p.selected].p;
|
||||
}
|
||||
|
||||
sampled_tokens[i] = sampled_token;
|
||||
confidences.emplace_back(confidence, i);
|
||||
}
|
||||
|
||||
int32_t num_transfer =
|
||||
(step < params.steps - 1) ? (int32_t) (mask_positions.size() * (1.0f - s / t)) : mask_positions.size();
|
||||
|
||||
if (num_transfer > 0) {
|
||||
if (params.alg_temp == 0.0f) {
|
||||
std::partial_sort(confidences.begin(), confidences.begin() + num_transfer, confidences.end(),
|
||||
[](const std::pair<float, int32_t> & a, const std::pair<float, int32_t> & b) {
|
||||
if (a.first != b.first) {
|
||||
return a.first > b.first;
|
||||
}
|
||||
return a.second < b.second;
|
||||
});
|
||||
} else {
|
||||
conf_candidates.clear();
|
||||
|
||||
for (int32_t pos = 0; pos < max_length; pos++) {
|
||||
float conf_logit = -std::numeric_limits<float>::infinity();
|
||||
|
||||
auto it = std::find(mask_positions.begin(), mask_positions.end(), pos);
|
||||
if (it != mask_positions.end()) {
|
||||
size_t mask_idx = std::distance(mask_positions.begin(), it);
|
||||
conf_logit = confidences[mask_idx].first / params.alg_temp; // Apply temperature scaling
|
||||
}
|
||||
|
||||
conf_candidates.emplace_back(llama_token_data{ pos, conf_logit, 0.0f });
|
||||
}
|
||||
|
||||
llama_token_data_array conf_array = {
|
||||
/* .data = */ conf_candidates.data(),
|
||||
/* .size = */ conf_candidates.size(),
|
||||
/* .selected = */ -1,
|
||||
/* .sorted = */ false,
|
||||
};
|
||||
|
||||
for (int32_t i = 0; i < num_transfer; i++) {
|
||||
// Apply distribution sampler to get selected index
|
||||
llama_sampler_apply(dist_sampler, &conf_array);
|
||||
int selected_idx = conf_array.selected;
|
||||
confidences[i].second = conf_candidates[selected_idx].id;
|
||||
|
||||
conf_candidates[selected_idx].p = 0.0f;
|
||||
conf_array.selected = -1;
|
||||
}
|
||||
}
|
||||
|
||||
if (params.alg_temp == 0.0f) {
|
||||
// Deterministic - use confidence order
|
||||
for (int32_t i = 0; i < num_transfer; i++) {
|
||||
int32_t mask_idx = confidences[i].second;
|
||||
int32_t pos = mask_positions[mask_idx];
|
||||
llama_token token = sampled_tokens[mask_idx];
|
||||
output_tokens[pos] = token;
|
||||
}
|
||||
} else {
|
||||
for (int32_t i = 0; i < num_transfer; i++) {
|
||||
int32_t pos = confidences[i].second;
|
||||
auto it = std::find(mask_positions.begin(), mask_positions.end(), pos);
|
||||
if (it != mask_positions.end()) {
|
||||
int32_t mask_idx = std::distance(mask_positions.begin(), it);
|
||||
output_tokens[pos] = sampled_tokens[mask_idx];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
int64_t time_end_sampling = ggml_time_us();
|
||||
total_sampling_time += time_end_sampling - time_start_sampling;
|
||||
}
|
||||
int64_t time_end = ggml_time_us();
|
||||
total_time += time_end - time_start;
|
||||
|
||||
LOG_INF("\ntotal time: %0.2fms, time per step: %0.2fms, sampling time per step: %0.2fms\n",
|
||||
total_time / 1000.0, total_time / 1000.0 / params.steps, total_sampling_time / 1000.0 / params.steps);
|
||||
|
||||
|
||||
llama_batch_free(batch);
|
||||
llama_sampler_free(sampler);
|
||||
llama_sampler_free(dist_sampler);
|
||||
|
||||
n_generated = max_length;
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
static std::string format_input_text(const std::string & prompt, bool use_chat_template, llama_model * model) {
|
||||
if (!use_chat_template) {
|
||||
return prompt;
|
||||
}
|
||||
|
||||
auto chat_templates = common_chat_templates_init(model, "");
|
||||
|
||||
common_chat_templates_inputs inputs;
|
||||
common_chat_msg user_msg;
|
||||
user_msg.role = "user";
|
||||
user_msg.content = prompt;
|
||||
inputs.add_generation_prompt = true;
|
||||
inputs.messages.push_back(user_msg);
|
||||
|
||||
auto result = common_chat_templates_apply(chat_templates.get(), inputs);
|
||||
|
||||
return result.prompt;
|
||||
}
|
||||
|
||||
struct callback_data {
|
||||
const common_params_diffusion * diff_params;
|
||||
const llama_vocab * vocab;
|
||||
int32_t n_input;
|
||||
};
|
||||
|
||||
static bool diffusion_step_callback(int32_t step,
|
||||
int32_t total_steps,
|
||||
const llama_token * tokens,
|
||||
int32_t n_tokens,
|
||||
void * user_data) {
|
||||
(void)user_data;
|
||||
|
||||
callback_data * data = static_cast<callback_data *>(user_data);
|
||||
|
||||
auto print_progress_bar = [](int32_t step, int32_t total_steps) {
|
||||
int progress_percent = (step * 100) / total_steps;
|
||||
int progress_bars = (step * 50) / total_steps;
|
||||
LOG_INF("\rdiffusion step: %d/%d [%s%s] %d%%",
|
||||
step,
|
||||
total_steps,
|
||||
std::string(progress_bars, '=').c_str(),
|
||||
std::string(50 - progress_bars, ' ').c_str(),
|
||||
progress_percent);
|
||||
};
|
||||
|
||||
if (data->diff_params->visual_mode) {
|
||||
// Visual mode: clear
|
||||
LOG_INF("\033[2J\033[H"); // Clear screen and move cursor to top-left
|
||||
|
||||
print_progress_bar(step, total_steps);
|
||||
|
||||
LOG_INF("\n");
|
||||
|
||||
std::string current_text = " ";
|
||||
|
||||
for (int32_t i = data->n_input; i < n_tokens; i++) {
|
||||
std::string token_str;
|
||||
if (tokens[i] != llama_vocab_mask(data->vocab)) {
|
||||
char piece[256];
|
||||
int n_chars = llama_token_to_piece(data->vocab, tokens[i], piece, sizeof(piece), 0, false);
|
||||
if (n_chars > 0) {
|
||||
piece[n_chars] = '\0';
|
||||
token_str = piece;
|
||||
}
|
||||
} else {
|
||||
token_str = " ";
|
||||
}
|
||||
|
||||
current_text += token_str;
|
||||
}
|
||||
|
||||
LOG_INF("%s\n", current_text.c_str());
|
||||
} else {
|
||||
print_progress_bar(step, total_steps);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
ggml_time_init();
|
||||
|
||||
common_params params;
|
||||
|
||||
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_DIFFUSION)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
const char * alg_names[] = { "ORIGIN", "MASKGIT_PLUS", "TOPK_MARGIN", "ENTROPY" };
|
||||
const char * alg_name = (params.diffusion.algorithm >= 0 && params.diffusion.algorithm <= 3) ?
|
||||
alg_names[params.diffusion.algorithm] :
|
||||
"UNKNOWN";
|
||||
|
||||
common_init();
|
||||
llama_backend_init();
|
||||
|
||||
llama_model_params model_params = llama_model_default_params();
|
||||
model_params.n_gpu_layers = params.n_gpu_layers;
|
||||
model_params.devices = params.devices.data();
|
||||
model_params.use_mmap = params.use_mmap;
|
||||
model_params.use_mlock = params.use_mlock;
|
||||
model_params.check_tensors = params.check_tensors;
|
||||
|
||||
llama_model * model = llama_model_load_from_file(params.model.path.c_str(), model_params);
|
||||
if (!model) {
|
||||
LOG_ERR("error: failed to load model '%s'\n", params.model.path.c_str());
|
||||
return 1;
|
||||
}
|
||||
|
||||
llama_context_params ctx_params = llama_context_default_params();
|
||||
ctx_params.n_ctx = params.n_ctx;
|
||||
ctx_params.n_batch = params.n_batch;
|
||||
ctx_params.n_ubatch = params.n_ubatch;
|
||||
ctx_params.flash_attn = params.flash_attn;
|
||||
ctx_params.no_perf = params.no_perf;
|
||||
ctx_params.type_k = params.cache_type_k;
|
||||
ctx_params.type_v = params.cache_type_v;
|
||||
|
||||
llama_context * ctx = llama_init_from_model(model, ctx_params);
|
||||
if (!ctx) {
|
||||
LOG_ERR("error: failed to create context\n");
|
||||
llama_model_free(model);
|
||||
return 1;
|
||||
}
|
||||
|
||||
llama_set_n_threads(ctx, params.cpuparams.n_threads, params.cpuparams_batch.n_threads);
|
||||
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
std::string formatted_prompt = format_input_text(params.prompt, params.enable_chat_template, model);
|
||||
|
||||
std::vector<llama_token> input_tokens = common_tokenize(vocab, formatted_prompt,
|
||||
/*add special tokens*/ true,
|
||||
/*parse special*/ true);
|
||||
int n_input = input_tokens.size();
|
||||
|
||||
if (n_input >= params.n_ctx) {
|
||||
LOG_ERR("error: input too long (%d tokens), max context is %d\n", n_input, params.n_ctx);
|
||||
llama_free(ctx);
|
||||
llama_model_free(model);
|
||||
return 1;
|
||||
}
|
||||
|
||||
struct diffusion_params ldiff_params = diffusion_default_params();
|
||||
ldiff_params.steps = params.diffusion.steps;
|
||||
ldiff_params.eps = params.diffusion.eps;
|
||||
ldiff_params.temperature = params.sampling.temp;
|
||||
ldiff_params.top_p = params.sampling.top_p;
|
||||
ldiff_params.top_k = params.sampling.top_k;
|
||||
ldiff_params.algorithm = static_cast<enum diffusion_alg>(params.diffusion.algorithm);
|
||||
ldiff_params.alg_temp = params.diffusion.alg_temp;
|
||||
ldiff_params.seed = params.sampling.seed;
|
||||
|
||||
llama_token mask_token_id = llama_vocab_mask(vocab);
|
||||
GGML_ASSERT(mask_token_id != LLAMA_TOKEN_NULL);
|
||||
|
||||
LOG_INF("diffusion_params: - %-25s llama_token = %d\n", "mask_token_id", mask_token_id);
|
||||
LOG_INF("diffusion_params: - %-25s u32 = %d\n", "steps", params.diffusion.steps);
|
||||
LOG_INF("diffusion_params: - %-25s f32 = %.6f\n", "eps", params.diffusion.eps);
|
||||
LOG_INF("diffusion_params: - %-25s u32 = %d (%s)\n", "algorithm", params.diffusion.algorithm,
|
||||
alg_name);
|
||||
LOG_INF("diffusion_params: - %-25s f32 = %.3f\n", "alg_temp", params.diffusion.alg_temp);
|
||||
|
||||
ldiff_params.mask_token_id = mask_token_id;
|
||||
|
||||
callback_data cb_data = { ¶ms.diffusion, vocab, n_input };
|
||||
|
||||
ldiff_params.step_callback = diffusion_step_callback;
|
||||
ldiff_params.step_callback_user_data = &cb_data;
|
||||
|
||||
int32_t n_generated = 0;
|
||||
|
||||
std::vector<llama_token> output_tokens(params.n_ubatch);
|
||||
diffusion_generate(ctx, input_tokens.data(), output_tokens.data(), n_input, params.n_ubatch,
|
||||
ldiff_params, n_generated);
|
||||
|
||||
if (n_generated > 0) {
|
||||
if (params.diffusion.visual_mode) {
|
||||
//clear screen and move cursor to top-left
|
||||
LOG_INF("\033[2J\033[H");
|
||||
}
|
||||
output_tokens.erase(output_tokens.begin(), output_tokens.begin() + n_input);
|
||||
std::string output_data = common_detokenize(vocab, output_tokens, false);
|
||||
LOG_INF("\n%s\n", output_data.c_str());
|
||||
} else {
|
||||
LOG_INF("Error: diffusion generation failed\n");
|
||||
}
|
||||
|
||||
llama_free(ctx);
|
||||
llama_model_free(model);
|
||||
llama_backend_free();
|
||||
|
||||
return 0;
|
||||
}
|
|
@ -107,7 +107,7 @@ int main(int argc, char ** argv) {
|
|||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
|
||||
const int n_ctx_train = llama_model_n_ctx_train(model);
|
||||
const int n_ctx = llama_n_ctx(ctx);
|
||||
const int n_ctx = llama_n_ctx(ctx);
|
||||
|
||||
const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
|
||||
|
||||
|
|
19
ggml/include/ggml-webgpu.h
Normal file
19
ggml/include/ggml-webgpu.h
Normal file
|
@ -0,0 +1,19 @@
|
|||
#pragma once
|
||||
|
||||
#include "ggml.h"
|
||||
#include "ggml-backend.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
#define GGML_WEBGPU_NAME "WebGPU"
|
||||
|
||||
// Needed for examples in ggml
|
||||
GGML_BACKEND_API ggml_backend_t ggml_backend_webgpu_init(void);
|
||||
|
||||
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_webgpu_reg(void);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
|
@ -45,6 +45,10 @@
|
|||
#include "ggml-vulkan.h"
|
||||
#endif
|
||||
|
||||
#ifdef GGML_USE_WEBGPU
|
||||
#include "ggml-webgpu.h"
|
||||
#endif
|
||||
|
||||
#ifdef GGML_USE_OPENCL
|
||||
#include "ggml-opencl.h"
|
||||
#endif
|
||||
|
@ -173,6 +177,9 @@ struct ggml_backend_registry {
|
|||
#ifdef GGML_USE_VULKAN
|
||||
register_backend(ggml_backend_vk_reg());
|
||||
#endif
|
||||
#ifdef GGML_USE_WEBGPU
|
||||
register_backend(ggml_backend_webgpu_reg());
|
||||
#endif
|
||||
#ifdef GGML_USE_OPENCL
|
||||
register_backend(ggml_backend_opencl_reg());
|
||||
#endif
|
||||
|
|
|
@ -4015,6 +4015,9 @@ static void ggml_compute_forward_rms_norm_f32(
|
|||
|
||||
const float scale = 1.0f/sqrtf(mean + eps);
|
||||
|
||||
// if you hit this, likely you got an inf somewhere earlier
|
||||
assert(scale > 0.0f);
|
||||
|
||||
ggml_vec_scale_f32(ne00, y, scale);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -221,6 +221,9 @@ void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * G
|
|||
for (int i = np; i < n; ++i) {
|
||||
sumf += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[i])*GGML_CPU_FP16_TO_FP32(y[i]));
|
||||
}
|
||||
|
||||
// if you hit this, you are likely running outside the FP range
|
||||
assert(!isnan(sumf) && !isinf(sumf));
|
||||
#else
|
||||
for (int i = 0; i < n; ++i) {
|
||||
sumf += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[i])*GGML_CPU_FP16_TO_FP32(y[i]));
|
||||
|
|
|
@ -33,8 +33,10 @@ typedef void (* fattn_kernel_t)(
|
|||
const int ne13,
|
||||
const int ne31,
|
||||
const int ne32,
|
||||
const int ne33,
|
||||
const int nb31,
|
||||
const int nb32,
|
||||
const int nb33,
|
||||
const int nb01,
|
||||
const int nb02,
|
||||
const int nb03,
|
||||
|
@ -521,7 +523,7 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
|
|||
template<int D, int ncols1, int ncols2> // D == head size
|
||||
__launch_bounds__(D, 1)
|
||||
static __global__ void flash_attn_stream_k_fixup(
|
||||
float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) {
|
||||
float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, const int ne11) {
|
||||
constexpr int ncols = ncols1*ncols2;
|
||||
|
||||
const int bidx0 = blockIdx.x;
|
||||
|
@ -535,8 +537,8 @@ static __global__ void flash_attn_stream_k_fixup(
|
|||
const int iter_k = ne11 / FATTN_KQ_STRIDE;
|
||||
const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
|
||||
|
||||
const int kbc0 = (bidx0 + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
|
||||
const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
|
||||
const int kbc0 = (bidx0 + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
|
||||
const int kbc0_stop = (bidx0 + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
|
||||
|
||||
const bool did_not_have_any_data = kbc0 == kbc0_stop;
|
||||
const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
|
||||
|
@ -545,14 +547,15 @@ static __global__ void flash_attn_stream_k_fixup(
|
|||
return;
|
||||
}
|
||||
|
||||
const int channel = kbc0 / (iter_k*iter_j);
|
||||
const int jt = (kbc0 - channel*iter_k*iter_j) / iter_k;
|
||||
const int sequence = kbc0 / (iter_k*iter_j*(ne02/ncols2));
|
||||
const int head = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
|
||||
const int jt = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
|
||||
|
||||
if (jt*ncols1 + j >= ne01) {
|
||||
return;
|
||||
}
|
||||
|
||||
dst += jt*ne02*(ncols1*D) + channel*(ncols2*D) + (j*ne02 + c)*D + tid;
|
||||
dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + head*(ncols2*D) + (j*ne02 + c)*D + tid;
|
||||
|
||||
// Load the partial result that needs a fixup:
|
||||
float dst_val = 0.0f;
|
||||
|
@ -571,7 +574,7 @@ static __global__ void flash_attn_stream_k_fixup(
|
|||
int bidx = bidx0 - 1;
|
||||
int kbc_stop = kbc0;
|
||||
while(true) {
|
||||
const int kbc = bidx*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
|
||||
const int kbc = bidx*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
|
||||
if (kbc == kbc_stop) { // Did not have any data.
|
||||
bidx--;
|
||||
kbc_stop = kbc;
|
||||
|
@ -617,16 +620,31 @@ static __global__ void flash_attn_combine_results(
|
|||
const float2 * __restrict__ VKQ_meta,
|
||||
float * __restrict__ dst,
|
||||
const int parallel_blocks) {
|
||||
VKQ_parts += parallel_blocks*D * gridDim.z*blockIdx.x;
|
||||
VKQ_meta += parallel_blocks * gridDim.z*blockIdx.x;
|
||||
dst += D * gridDim.z*blockIdx.x;
|
||||
// Dimension 0: threadIdx.x
|
||||
// Dimension 1: blockIdx.x
|
||||
// Dimension 2: blockIdx.y
|
||||
// Dimension 3: blockIdx.z
|
||||
// Memory layout is permuted with [0, 2, 1, 3]
|
||||
|
||||
const int ne01 = gridDim.x;
|
||||
const int ne02 = gridDim.y;
|
||||
|
||||
const int col = blockIdx.x;
|
||||
const int head = blockIdx.y;
|
||||
const int sequence = blockIdx.z;
|
||||
|
||||
const int j_dst_unrolled = (sequence*ne01 + col)*ne02 + head;
|
||||
|
||||
VKQ_parts += j_dst_unrolled * parallel_blocks*D;
|
||||
VKQ_meta += j_dst_unrolled * parallel_blocks;
|
||||
dst += j_dst_unrolled * D;
|
||||
|
||||
const int tid = threadIdx.x;
|
||||
__builtin_assume(tid < D);
|
||||
|
||||
extern __shared__ float2 meta[];
|
||||
for (int i = tid; i < 2*parallel_blocks; i += D) {
|
||||
((float *) meta)[i] = ((const float *)VKQ_meta) [blockIdx.z*(2*parallel_blocks) + i];
|
||||
((float *) meta)[i] = ((const float *)VKQ_meta) [i];
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
@ -644,11 +662,11 @@ static __global__ void flash_attn_combine_results(
|
|||
const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
|
||||
*((uint32_t *) &KQ_max_scale) &= ftz_mask;
|
||||
|
||||
VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.z*D + blockIdx.z*D + tid];
|
||||
VKQ_numerator += KQ_max_scale * VKQ_parts[l*D + tid];
|
||||
VKQ_denominator += KQ_max_scale * meta[l].y;
|
||||
}
|
||||
|
||||
dst[blockIdx.z*D + tid] = VKQ_numerator / VKQ_denominator;
|
||||
dst[tid] = VKQ_numerator / VKQ_denominator;
|
||||
}
|
||||
|
||||
[[noreturn]]
|
||||
|
@ -705,8 +723,6 @@ void launch_fattn(
|
|||
|
||||
GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
|
||||
|
||||
GGML_ASSERT(Q->ne[3] == 1);
|
||||
|
||||
ggml_cuda_pool & pool = ctx.pool();
|
||||
cudaStream_t main_stream = ctx.stream();
|
||||
const int id = ggml_cuda_get_device();
|
||||
|
@ -853,8 +869,8 @@ void launch_fattn(
|
|||
scale, max_bias, m0, m1, n_head_log2, logit_softcap,
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||
mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0,
|
||||
mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0,
|
||||
mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0,
|
||||
mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0, mask ? mask->nb[3] : 0,
|
||||
Q->nb[1], Q->nb[2], Q->nb[3],
|
||||
nb11, nb12, nb13,
|
||||
nb21, nb22, nb23,
|
||||
|
@ -869,11 +885,11 @@ void launch_fattn(
|
|||
|
||||
flash_attn_stream_k_fixup<DV, ncols1, ncols2>
|
||||
<<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
|
||||
((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
|
||||
((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1]);
|
||||
}
|
||||
} else if (parallel_blocks > 1) {
|
||||
const dim3 block_dim_combine(DV, 1, 1);
|
||||
const dim3 blocks_num_combine(Q->ne[1], 1, blocks_num.z);
|
||||
const dim3 blocks_num_combine(Q->ne[1], Q->ne[2], Q->ne[3]);
|
||||
const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2);
|
||||
|
||||
flash_attn_combine_results<DV>
|
||||
|
|
|
@ -1224,8 +1224,10 @@ static __global__ void flash_attn_ext_f16(
|
|||
const int ne13,
|
||||
const int ne31,
|
||||
const int ne32,
|
||||
const int ne33,
|
||||
const int nb31,
|
||||
const int nb32,
|
||||
const int nb33,
|
||||
const int nb01,
|
||||
const int nb02,
|
||||
const int nb03,
|
||||
|
@ -1274,8 +1276,8 @@ static __global__ void flash_attn_ext_f16(
|
|||
constexpr int kb_niter = FATTN_KQ_STRIDE / c::nbatch_fa; // Number of kernel iterations per assigned KQ slice.
|
||||
|
||||
// kbc == k block continuous, current index in continuous ijk space.
|
||||
int kbc = (blockIdx.x + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
|
||||
const int kbc_stop = (blockIdx.x + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
|
||||
int kbc = (blockIdx.x + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
|
||||
const int kbc_stop = (blockIdx.x + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
|
||||
|
||||
// If the seams of 2 CUDA blocks fall within an output tile their results need to be combined.
|
||||
// For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup).
|
||||
|
@ -1285,18 +1287,19 @@ static __global__ void flash_attn_ext_f16(
|
|||
int kb0_start = kbc % iter_k;
|
||||
int kb0_stop = min(iter_k, kb0_start + kbc_stop - kbc);
|
||||
while (kbc < kbc_stop && kb0_stop == iter_k) {
|
||||
const int channel = kbc / (iter_k*iter_j);
|
||||
const int jt = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile.
|
||||
const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
|
||||
const int head = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
|
||||
const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
|
||||
|
||||
const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
|
||||
const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
|
||||
const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*(head*ncols2));
|
||||
const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head*ncols2 / gqa_ratio));
|
||||
const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr :
|
||||
(const half2 *) (mask + nb32*(channel % ne32) + nb31*jt*ncols1);
|
||||
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
|
||||
(const half2 *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1);
|
||||
float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head*ncols2) * (DV/2);
|
||||
|
||||
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
|
||||
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head*ncols2 / gqa_ratio));
|
||||
|
||||
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
|
||||
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f;
|
||||
|
||||
const int kb0_start_kernel = kb0_start * kb_niter;
|
||||
const int kb0_stop_kernel = kb0_stop * kb_niter;
|
||||
|
@ -1325,18 +1328,19 @@ static __global__ void flash_attn_ext_f16(
|
|||
return;
|
||||
}
|
||||
|
||||
const int channel = kbc / (iter_k*iter_j);
|
||||
const int jt = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile.
|
||||
const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
|
||||
const int head = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
|
||||
const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
|
||||
|
||||
const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
|
||||
const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
|
||||
const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*(head*ncols2));
|
||||
const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head*ncols2 / gqa_ratio));
|
||||
const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr :
|
||||
(const half2 *) (mask + nb32*(channel % ne32) + nb31*jt*ncols1);
|
||||
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
|
||||
(const half2 *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1);
|
||||
float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head*ncols2) * (DV/2);
|
||||
|
||||
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
|
||||
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head*ncols2 / gqa_ratio));
|
||||
|
||||
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
|
||||
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f;
|
||||
|
||||
const int kb0_start_kernel = kb0_start * kb_niter;
|
||||
const int kb0_stop_kernel = kb0_stop * kb_niter;
|
||||
|
|
|
@ -31,8 +31,10 @@ static __global__ void flash_attn_tile_ext_f16(
|
|||
const int ne13,
|
||||
const int ne31,
|
||||
const int ne32,
|
||||
const int ne33,
|
||||
const int nb31,
|
||||
const int nb32,
|
||||
const int nb33,
|
||||
const int nb01,
|
||||
const int nb02,
|
||||
const int nb03,
|
||||
|
@ -62,15 +64,17 @@ static __global__ void flash_attn_tile_ext_f16(
|
|||
|
||||
const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
|
||||
|
||||
const int sequence = blockIdx.z / ne02;
|
||||
const int head = blockIdx.z - sequence*ne02;
|
||||
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
||||
const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.z + nb01*ic0);
|
||||
const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.z / gqa_ratio));
|
||||
const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
|
||||
const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
|
||||
const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0);
|
||||
const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio));
|
||||
const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
|
||||
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
|
||||
|
||||
const int stride_KV2 = nb11 / sizeof(half2);
|
||||
|
||||
const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
|
||||
const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
|
||||
const half slopeh = __float2half(slopef);
|
||||
|
||||
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
|
||||
|
@ -255,6 +259,8 @@ static __global__ void flash_attn_tile_ext_f16(
|
|||
__syncthreads();
|
||||
}
|
||||
|
||||
float2 * dst2 = (float2 *) dst;
|
||||
|
||||
#pragma unroll
|
||||
for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
|
||||
const int j_VKQ = j_VKQ_0 + threadIdx.y;
|
||||
|
@ -266,21 +272,21 @@ static __global__ void flash_attn_tile_ext_f16(
|
|||
half kqsum_j = __low2half(kqsum[j_VKQ_0/nwarps]) + __high2half(kqsum[j_VKQ_0/nwarps]);
|
||||
kqsum_j = warp_reduce_sum((float)kqsum_j);
|
||||
|
||||
#pragma unroll
|
||||
for (int i00 = 0; i00 < D; i00 += 2*WARP_SIZE) {
|
||||
const int i0 = i00 + 2*threadIdx.x;
|
||||
const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
|
||||
|
||||
half2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)];
|
||||
#pragma unroll
|
||||
for (int i00 = 0; i00 < D/2; i00 += WARP_SIZE) {
|
||||
const int i0 = i00 + threadIdx.x;
|
||||
|
||||
half2 dst_val = VKQ[j_VKQ_0/nwarps][i0/WARP_SIZE];
|
||||
if (gridDim.y == 1) {
|
||||
dst_val /= __half2half2(kqsum_j);
|
||||
}
|
||||
const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
|
||||
dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 0] = __low2float(dst_val);
|
||||
dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 1] = __high2float(dst_val);
|
||||
dst2[j_dst_unrolled*(D/2) + i0] = __half22float2(dst_val);
|
||||
}
|
||||
|
||||
if (gridDim.y != 1 && threadIdx.x == 0) {
|
||||
dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
|
||||
dst_meta[j_dst_unrolled] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
|
||||
}
|
||||
}
|
||||
#else
|
||||
|
@ -290,8 +296,8 @@ static __global__ void flash_attn_tile_ext_f16(
|
|||
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
|
||||
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
|
||||
GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
|
||||
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
|
||||
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
|
||||
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
|
||||
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
|
||||
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
|
||||
GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
|
||||
GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
|
||||
|
|
|
@ -31,8 +31,10 @@ static __global__ void flash_attn_tile_ext_f32(
|
|||
const int ne13,
|
||||
const int ne31,
|
||||
const int ne32,
|
||||
const int ne33,
|
||||
const int nb31,
|
||||
const int nb32,
|
||||
const int nb33,
|
||||
const int nb01,
|
||||
const int nb02,
|
||||
const int nb03,
|
||||
|
@ -74,15 +76,17 @@ static __global__ void flash_attn_tile_ext_f32(
|
|||
|
||||
const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
|
||||
|
||||
const int sequence = blockIdx.z / ne02;
|
||||
const int head = blockIdx.z - sequence*ne02;
|
||||
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
||||
const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.z + nb01*ic0);
|
||||
const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.z / gqa_ratio));
|
||||
const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
|
||||
const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
|
||||
const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0);
|
||||
const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio));
|
||||
const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
|
||||
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
|
||||
|
||||
const int stride_KV2 = nb11 / sizeof(half2);
|
||||
|
||||
const float slope = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
|
||||
const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
|
||||
|
||||
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
|
||||
|
||||
|
@ -265,6 +269,8 @@ static __global__ void flash_attn_tile_ext_f32(
|
|||
__syncthreads();
|
||||
}
|
||||
|
||||
float2 * dst2 = (float2 *) dst;
|
||||
|
||||
#pragma unroll
|
||||
for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
|
||||
const int j_VKQ = j_VKQ_0 + threadIdx.y;
|
||||
|
@ -276,22 +282,22 @@ static __global__ void flash_attn_tile_ext_f32(
|
|||
float kqsum_j = kqsum[j_VKQ_0/nwarps];
|
||||
kqsum_j = warp_reduce_sum(kqsum_j);
|
||||
|
||||
#pragma unroll
|
||||
for (int i00 = 0; i00 < D; i00 += 2*WARP_SIZE) {
|
||||
const int i0 = i00 + 2*threadIdx.x;
|
||||
const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
|
||||
|
||||
float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)];
|
||||
#pragma unroll
|
||||
for (int i00 = 0; i00 < D/2; i00 += WARP_SIZE) {
|
||||
const int i0 = i00 + threadIdx.x;
|
||||
|
||||
float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/WARP_SIZE];
|
||||
if (gridDim.y == 1) {
|
||||
dst_val.x /= kqsum_j;
|
||||
dst_val.y /= kqsum_j;
|
||||
}
|
||||
const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
|
||||
dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 0] = dst_val.x;
|
||||
dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 1] = dst_val.y;
|
||||
dst2[j_dst_unrolled*(D/2) + i0] = dst_val;
|
||||
}
|
||||
|
||||
if (gridDim.y != 1 && threadIdx.x == 0) {
|
||||
dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
|
||||
dst_meta[j_dst_unrolled] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
|
||||
}
|
||||
}
|
||||
#else
|
||||
|
|
|
@ -28,8 +28,10 @@ static __global__ void flash_attn_vec_ext_f16(
|
|||
const int ne13,
|
||||
const int ne31,
|
||||
const int ne32,
|
||||
const int ne33,
|
||||
const int nb31,
|
||||
const int nb32,
|
||||
const int nb33,
|
||||
const int nb01,
|
||||
const int nb02,
|
||||
const int nb03,
|
||||
|
@ -65,14 +67,16 @@ static __global__ void flash_attn_vec_ext_f16(
|
|||
|
||||
const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
|
||||
|
||||
const int sequence = blockIdx.z / ne02;
|
||||
const int head = blockIdx.z - sequence*ne02;
|
||||
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
||||
Q += nb02* blockIdx.z + nb01*ic0;
|
||||
K += nb12*(blockIdx.z / gqa_ratio);
|
||||
V += nb22*(blockIdx.z / gqa_ratio);
|
||||
Q += nb03*sequence + nb02* head + nb01*ic0;
|
||||
K += nb13*sequence + nb12*(head / gqa_ratio);
|
||||
V += nb23*sequence + nb22*(head / gqa_ratio);
|
||||
|
||||
const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
|
||||
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
|
||||
|
||||
const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
|
||||
const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
|
||||
const half slopeh = __float2half(slopef);
|
||||
|
||||
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
|
||||
|
@ -330,12 +334,11 @@ static __global__ void flash_attn_vec_ext_f16(
|
|||
if (gridDim.y == 1) {
|
||||
dst_val /= kqsum[j_VKQ];
|
||||
}
|
||||
const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
|
||||
dst[j_dst*D*gridDim.z + D*blockIdx.z + tid] = dst_val;
|
||||
dst[(((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*D + tid] = dst_val;
|
||||
}
|
||||
|
||||
if (gridDim.y != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
|
||||
dst_meta[((ic0 + tid)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
|
||||
dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
|
||||
}
|
||||
#else
|
||||
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
|
||||
|
@ -344,8 +347,8 @@ static __global__ void flash_attn_vec_ext_f16(
|
|||
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
|
||||
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
|
||||
GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
|
||||
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
|
||||
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
|
||||
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne32);
|
||||
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
|
||||
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
|
||||
GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
|
||||
GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
|
||||
|
|
|
@ -28,8 +28,10 @@ static __global__ void flash_attn_vec_ext_f32(
|
|||
const int ne13,
|
||||
const int ne31,
|
||||
const int ne32,
|
||||
const int ne33,
|
||||
const int nb31,
|
||||
const int nb32,
|
||||
const int nb33,
|
||||
const int nb01,
|
||||
const int nb02,
|
||||
const int nb03,
|
||||
|
@ -53,8 +55,8 @@ static __global__ void flash_attn_vec_ext_f32(
|
|||
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
|
||||
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
|
||||
GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
|
||||
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
|
||||
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
|
||||
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
|
||||
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
|
||||
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
|
||||
GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
|
||||
GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
|
||||
|
@ -77,14 +79,16 @@ static __global__ void flash_attn_vec_ext_f32(
|
|||
|
||||
const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
|
||||
|
||||
const int sequence = blockIdx.z / ne02;
|
||||
const int head = blockIdx.z - sequence*ne02;
|
||||
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
||||
Q += nb02* blockIdx.z + nb01*ic0;
|
||||
K += nb12*(blockIdx.z / gqa_ratio);
|
||||
V += nb22*(blockIdx.z / gqa_ratio); // K and V have same shape
|
||||
Q += nb03*sequence + nb02* head + nb01*ic0;
|
||||
K += nb13*sequence + nb12*(head / gqa_ratio);
|
||||
V += nb23*sequence + nb22*(head / gqa_ratio);
|
||||
|
||||
const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
|
||||
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
|
||||
|
||||
const float slope = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
|
||||
const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
|
||||
|
||||
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
|
||||
constexpr int nwarps = D / WARP_SIZE;
|
||||
|
@ -326,12 +330,11 @@ static __global__ void flash_attn_vec_ext_f32(
|
|||
if (gridDim.y == 1) {
|
||||
dst_val /= kqsum[j_VKQ];
|
||||
}
|
||||
const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
|
||||
dst[j_dst*D*gridDim.z + D*blockIdx.z + tid] = dst_val;
|
||||
dst[(((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*D + tid] = dst_val;
|
||||
}
|
||||
|
||||
if (gridDim.y != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
|
||||
dst_meta[((ic0 + tid)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
|
||||
dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
|
||||
}
|
||||
#else
|
||||
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
|
||||
|
@ -340,8 +343,8 @@ static __global__ void flash_attn_vec_ext_f32(
|
|||
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
|
||||
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
|
||||
GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
|
||||
GGML_UNUSED(ne31); GGML_UNUSED(ne32);
|
||||
GGML_UNUSED(nb31); GGML_UNUSED(nb32);
|
||||
GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
|
||||
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33);
|
||||
GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
|
||||
GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
|
||||
GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
|
||||
|
|
|
@ -47,8 +47,10 @@ static __global__ void flash_attn_ext_f16(
|
|||
const int ne13,
|
||||
const int ne31,
|
||||
const int ne32,
|
||||
const int ne33,
|
||||
const int nb31,
|
||||
const int nb32,
|
||||
const int nb33,
|
||||
const int nb01,
|
||||
const int nb02,
|
||||
const int nb03,
|
||||
|
@ -95,17 +97,19 @@ static __global__ void flash_attn_ext_f16(
|
|||
constexpr int kqs_padded = FATTN_KQ_STRIDE + 8;
|
||||
constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half);
|
||||
|
||||
const int sequence = blockIdx.z / ne02;
|
||||
const int head = blockIdx.z - sequence*ne02;
|
||||
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
||||
const float * Q_f = (const float *) (Q + nb02* blockIdx.z + nb01*ic0);
|
||||
const half * K_h = (const half *) (K + nb12*(blockIdx.z / gqa_ratio));
|
||||
const half * V_h = (const half *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
|
||||
const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
|
||||
const float * Q_f = (const float *) (Q + nb03* sequence + nb02* head + nb01*ic0);
|
||||
const half * K_h = (const half *) (K + nb13* sequence + nb12*(head / gqa_ratio));
|
||||
const half * V_h = (const half *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
|
||||
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
|
||||
const half2 * mask2 = (const half2 *) maskh;
|
||||
|
||||
const int stride_Q = nb01 / sizeof(float);
|
||||
const int stride_KV = nb11 / sizeof(half);
|
||||
|
||||
const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
|
||||
const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
|
||||
const half slopeh = __float2half(slopef);
|
||||
const half2 slope2 = make_half2(slopef, slopef);
|
||||
|
||||
|
@ -400,7 +404,6 @@ static __global__ void flash_attn_ext_f16(
|
|||
if (ic0 + j_VKQ >= ne01) {
|
||||
return;
|
||||
}
|
||||
const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
|
||||
|
||||
float KQ_rowsum_j;
|
||||
if (std::is_same<KQ_acc_t, float>::value) {
|
||||
|
@ -409,6 +412,8 @@ static __global__ void flash_attn_ext_f16(
|
|||
KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]);
|
||||
}
|
||||
|
||||
const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D; i0 += warp_size) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
|
@ -419,7 +424,7 @@ static __global__ void flash_attn_ext_f16(
|
|||
if (gridDim.y == 1) {
|
||||
dst_val /= KQ_rowsum_j;
|
||||
}
|
||||
dst[j_dst*gridDim.z*D + blockIdx.z*D + i] = dst_val;
|
||||
dst[j_dst_unrolled*D + i] = dst_val;
|
||||
}
|
||||
|
||||
if (gridDim.y == 1 || threadIdx.x != 0) {
|
||||
|
@ -433,7 +438,7 @@ static __global__ void flash_attn_ext_f16(
|
|||
dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]);
|
||||
}
|
||||
dst_meta_val.y = KQ_rowsum_j;
|
||||
dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = dst_meta_val;
|
||||
dst_meta[j_dst_unrolled] = dst_meta_val;
|
||||
}
|
||||
#else
|
||||
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
|
||||
|
@ -442,7 +447,8 @@ static __global__ void flash_attn_ext_f16(
|
|||
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
|
||||
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
|
||||
GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
|
||||
GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
|
||||
GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33); GGML_UNUSED(nb31);
|
||||
GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
|
||||
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
|
||||
GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
|
||||
GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);
|
||||
|
|
|
@ -3418,12 +3418,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
|||
if (op->src[0]->ne[0] == 192) {
|
||||
return false;
|
||||
}
|
||||
// TODO: support broadcast
|
||||
// note: this was initially implemented in https://github.com/ggml-org/llama.cpp/pull/14500, but
|
||||
// the interface of ggml_flash_attn_ext() changed in https://github.com/ggml-org/llama.cpp/pull/14505
|
||||
if (op->src[0]->ne[3] != 1) {
|
||||
return false;
|
||||
}
|
||||
if (op->src[1]->type == GGML_TYPE_BF16 || op->src[2]->type == GGML_TYPE_BF16) {
|
||||
return false;
|
||||
}
|
||||
|
@ -3436,6 +3430,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
|||
if (op->src[0]->ne[0] == 256 && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) {
|
||||
return true;
|
||||
}
|
||||
if (op->src[3] && op->src[3]->ne[2] != 1) {
|
||||
return false;
|
||||
}
|
||||
return fp16_mma_available(ggml_cuda_info().devices[dev_ctx->device].cc) &&
|
||||
op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
|
||||
}
|
||||
|
|
54
ggml/src/ggml-webgpu/CMakeLists.txt
Normal file
54
ggml/src/ggml-webgpu/CMakeLists.txt
Normal file
|
@ -0,0 +1,54 @@
|
|||
cmake_minimum_required(VERSION 3.13)
|
||||
|
||||
find_package(Python3 REQUIRED)
|
||||
|
||||
# Shader locations
|
||||
set(SHADER_DIR "${CMAKE_CURRENT_SOURCE_DIR}/wgsl-shaders")
|
||||
set(SHADER_OUTPUT_DIR "${CMAKE_CURRENT_BINARY_DIR}/generated")
|
||||
set(SHADER_HEADER "${SHADER_OUTPUT_DIR}/ggml-wgsl-shaders.hpp")
|
||||
file(MAKE_DIRECTORY ${SHADER_OUTPUT_DIR})
|
||||
|
||||
message(STATUS "Shader output dir: ${SHADER_OUTPUT_DIR}")
|
||||
|
||||
# Find all WGSL files
|
||||
file(GLOB WGSL_SHADER_FILES "${SHADER_DIR}/*.wgsl")
|
||||
|
||||
# Generate the header using a Python script
|
||||
add_custom_command(
|
||||
OUTPUT ${SHADER_HEADER}
|
||||
COMMAND ${CMAKE_COMMAND} -E echo "Embedding WGSL shaders to ggml-wgsl-shaders.hpp"
|
||||
COMMAND ${CMAKE_COMMAND} -E make_directory ${SHADER_OUTPUT_DIR}
|
||||
COMMAND ${CMAKE_COMMAND} -E env PYTHONIOENCODING=utf-8
|
||||
${Python3_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/wgsl-shaders/embed_wgsl.py
|
||||
--input "${SHADER_DIR}"
|
||||
--output "${SHADER_HEADER}"
|
||||
DEPENDS ${WGSL_SHADER_FILES} ${CMAKE_CURRENT_SOURCE_DIR}/wgsl-shaders/embed_wgsl.py
|
||||
VERBATIM
|
||||
)
|
||||
|
||||
add_custom_target(generate_shaders DEPENDS ${SHADER_HEADER})
|
||||
|
||||
ggml_add_backend_library(ggml-webgpu
|
||||
ggml-webgpu.cpp
|
||||
${SHADER_HEADER}
|
||||
../../include/ggml-webgpu.h
|
||||
)
|
||||
|
||||
add_dependencies(ggml-webgpu generate_shaders)
|
||||
|
||||
if(EMSCRIPTEN)
|
||||
set(EMDAWNWEBGPU_DIR "" CACHE PATH "Path to emdawnwebgpu_pkg")
|
||||
|
||||
target_compile_options(ggml-webgpu PRIVATE "--use-port=${EMDAWNWEBGPU_DIR}/emdawnwebgpu.port.py")
|
||||
target_link_options(ggml-webgpu PRIVATE "--use-port=${EMDAWNWEBGPU_DIR}/emdawnwebgpu.port.py")
|
||||
else()
|
||||
find_package(Dawn REQUIRED)
|
||||
set(DawnWebGPU_TARGET dawn::webgpu_dawn)
|
||||
endif()
|
||||
|
||||
if (GGML_WEBGPU_DEBUG)
|
||||
target_compile_definitions(ggml-webgpu PRIVATE GGML_WEBGPU_DEBUG=1)
|
||||
endif()
|
||||
|
||||
target_include_directories(ggml-webgpu PRIVATE ${SHADER_OUTPUT_DIR})
|
||||
target_link_libraries(ggml-webgpu PRIVATE ${DawnWebGPU_TARGET})
|
907
ggml/src/ggml-webgpu/ggml-webgpu.cpp
Normal file
907
ggml/src/ggml-webgpu/ggml-webgpu.cpp
Normal file
|
@ -0,0 +1,907 @@
|
|||
#include "ggml-webgpu.h"
|
||||
|
||||
#include <webgpu/webgpu_cpp.h>
|
||||
|
||||
#include "ggml-impl.h"
|
||||
#include "ggml-backend-impl.h"
|
||||
|
||||
#include "ggml-wgsl-shaders.hpp"
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <mutex>
|
||||
#include <vector>
|
||||
|
||||
#ifdef GGML_WEBGPU_DEBUG
|
||||
#define WEBGPU_LOG_DEBUG(msg) std::cout << msg << std::endl
|
||||
#else
|
||||
#define WEBGPU_LOG_DEBUG(msg) ((void) 0)
|
||||
#endif // GGML_WEBGPU_DEBUG
|
||||
|
||||
/* Constants */
|
||||
|
||||
#define WEBGPU_MUL_MAT_WG_SIZE 64
|
||||
#define WEBGPU_MUL_MAT_PARAMS_SIZE (13 * sizeof(uint32_t)) // M, N, K, batch sizes, broadcasts
|
||||
#define WEBGPU_CPY_PARAMS_SIZE (15 * sizeof(uint32_t)) // strides and offsets
|
||||
#define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4
|
||||
|
||||
/* End Constants */
|
||||
|
||||
// This is a "fake" base pointer, since WebGPU buffers do not have pointers to their locations.
|
||||
static void * const webgpu_ptr_base = (void *)(uintptr_t) 0x1000; // NOLINT
|
||||
|
||||
// Always returns the base offset of a tensor, regardless of views.
|
||||
static uint64_t webgpu_tensor_offset(const ggml_tensor * tensor) {
|
||||
if (tensor->view_src) {
|
||||
return (uint8_t *) tensor->view_src->data - (uint8_t *) webgpu_ptr_base;
|
||||
}
|
||||
return (uint8_t *) tensor->data - (uint8_t *) webgpu_ptr_base;
|
||||
}
|
||||
|
||||
/* Struct definitions */
|
||||
|
||||
// All the base objects needed to run operations on a WebGPU device
|
||||
struct webgpu_context_struct {
|
||||
wgpu::Instance instance;
|
||||
wgpu::Adapter adapter;
|
||||
wgpu::Device device;
|
||||
wgpu::Queue queue;
|
||||
wgpu::Limits limits;
|
||||
wgpu::SupportedFeatures features;
|
||||
|
||||
std::mutex mutex;
|
||||
bool device_initialized = false;
|
||||
|
||||
// pipelines and parameter buffers
|
||||
// TODO: reuse params buffers for different pipelines when possible
|
||||
wgpu::ComputePipeline memset_pipeline;
|
||||
wgpu::Buffer memset_params_dev_buf;
|
||||
wgpu::Buffer memset_params_host_buf;
|
||||
wgpu::ComputePipeline mul_mat_pipeline;
|
||||
wgpu::Buffer mul_mat_params_dev_buf;
|
||||
wgpu::Buffer mul_mat_params_host_buf;
|
||||
wgpu::ComputePipeline cpy_pipeline;
|
||||
wgpu::Buffer cpy_params_dev_buf;
|
||||
wgpu::Buffer cpy_params_host_buf;
|
||||
|
||||
size_t memset_bytes_per_thread;
|
||||
|
||||
// Staging buffer for reading data from the GPU
|
||||
wgpu::Buffer get_tensor_staging_buf;
|
||||
};
|
||||
|
||||
typedef std::shared_ptr<webgpu_context_struct> webgpu_context;
|
||||
|
||||
struct ggml_backend_webgpu_reg_context {
|
||||
webgpu_context webgpu_ctx;
|
||||
|
||||
size_t device_count;
|
||||
const char * name;
|
||||
};
|
||||
|
||||
struct ggml_backend_webgpu_device_context {
|
||||
webgpu_context webgpu_ctx;
|
||||
|
||||
std::string device_name;
|
||||
std::string device_desc;
|
||||
};
|
||||
|
||||
struct ggml_backend_webgpu_context {
|
||||
webgpu_context webgpu_ctx;
|
||||
|
||||
std::string name;
|
||||
};
|
||||
|
||||
struct ggml_backend_webgpu_buffer_context {
|
||||
webgpu_context webgpu_ctx;
|
||||
|
||||
wgpu::Buffer buffer;
|
||||
|
||||
ggml_backend_webgpu_buffer_context(webgpu_context ctx, wgpu::Buffer buf) :
|
||||
webgpu_ctx(ctx), buffer(buf) {
|
||||
}
|
||||
};
|
||||
|
||||
/* End struct definitions */
|
||||
|
||||
/* WebGPU object initializations */
|
||||
|
||||
static void ggml_webgpu_create_pipeline(wgpu::Device &device, wgpu::ComputePipeline &pipeline, const char * shader_code, const char * label, const std::vector<wgpu::ConstantEntry> &constants = {}) {
|
||||
WEBGPU_LOG_DEBUG("ggml_webgpu_create_pipeline()");
|
||||
wgpu::ShaderSourceWGSL shader_source;
|
||||
shader_source.code = shader_code;
|
||||
wgpu::ShaderModuleDescriptor shader_desc;
|
||||
shader_desc.nextInChain = &shader_source;
|
||||
wgpu::ShaderModule shader_module = device.CreateShaderModule(&shader_desc);
|
||||
|
||||
wgpu::ComputePipelineDescriptor pipeline_desc;
|
||||
pipeline_desc.label = label;
|
||||
pipeline_desc.compute.module = shader_module;
|
||||
pipeline_desc.compute.entryPoint = "main"; // Entry point in the WGSL code
|
||||
pipeline_desc.layout = nullptr; // nullptr means auto layout
|
||||
if (constants.size() > 0) {
|
||||
pipeline_desc.compute.constants = constants.data();
|
||||
pipeline_desc.compute.constantCount = constants.size();
|
||||
}
|
||||
pipeline = device.CreateComputePipeline(&pipeline_desc);
|
||||
}
|
||||
|
||||
static void ggml_webgpu_create_buffer(wgpu::Device &device, wgpu::Buffer &buffer, size_t size, wgpu::BufferUsage usage, const char* label) {
|
||||
WEBGPU_LOG_DEBUG("ggml_webgpu_create_buffer()");
|
||||
|
||||
wgpu::BufferDescriptor buffer_desc;
|
||||
buffer_desc.size = size;
|
||||
buffer_desc.usage = usage;
|
||||
buffer_desc.label = label;
|
||||
buffer_desc.mappedAtCreation = false;
|
||||
// TODO: error handling
|
||||
buffer = device.CreateBuffer(&buffer_desc);
|
||||
}
|
||||
|
||||
/** End WebGPU object initializations */
|
||||
|
||||
/** WebGPU Actions */
|
||||
|
||||
static void ggml_backend_webgpu_map_buffer(webgpu_context ctx, wgpu::Buffer buffer, wgpu::MapMode mode, size_t offset, size_t size) {
|
||||
ctx->instance.WaitAny(buffer.MapAsync(
|
||||
mode, offset, size, wgpu::CallbackMode::WaitAnyOnly,
|
||||
[](wgpu::MapAsyncStatus status, wgpu::StringView message) {
|
||||
if (status != wgpu::MapAsyncStatus::Success) {
|
||||
GGML_LOG_ERROR("ggml_webgpu: Failed to map buffer: %s\n", message.data);
|
||||
}
|
||||
}),
|
||||
UINT64_MAX
|
||||
);
|
||||
}
|
||||
|
||||
static void ggml_backend_webgpu_buffer_memset(webgpu_context ctx, wgpu::Buffer buf, uint32_t value, size_t offset, size_t size) {
|
||||
std::lock_guard<std::mutex> lock(ctx->mutex);
|
||||
wgpu::Device device = ctx->device;
|
||||
|
||||
// map the host parameters buffer
|
||||
ggml_backend_webgpu_map_buffer(ctx, ctx->memset_params_host_buf, wgpu::MapMode::Write, 0, ctx->memset_params_host_buf.GetSize());
|
||||
uint32_t * params = (uint32_t *) ctx->memset_params_host_buf.GetMappedRange();
|
||||
|
||||
params[0] = (uint32_t)offset;
|
||||
params[1] = (uint32_t)size;
|
||||
params[2] = value;
|
||||
ctx->memset_params_host_buf.Unmap();
|
||||
|
||||
wgpu::BindGroupEntry entries[2];
|
||||
entries[0].binding = 0; // binding for the buffer to memset
|
||||
entries[0].buffer = buf;
|
||||
entries[0].offset = 0;
|
||||
entries[0].size = buf.GetSize();
|
||||
entries[1].binding = 1; // binding for the parameters
|
||||
entries[1].buffer = ctx->memset_params_dev_buf;
|
||||
entries[1].offset = 0;
|
||||
entries[1].size = ctx->memset_params_dev_buf.GetSize();
|
||||
|
||||
wgpu::BindGroupDescriptor bind_group_desc;
|
||||
bind_group_desc.layout = ctx->memset_pipeline.GetBindGroupLayout(0);
|
||||
bind_group_desc.entryCount = 2;
|
||||
bind_group_desc.label = "ggml_memset";
|
||||
bind_group_desc.entries = entries;
|
||||
wgpu::BindGroup bind_group = device.CreateBindGroup(&bind_group_desc);
|
||||
|
||||
wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
|
||||
encoder.CopyBufferToBuffer(
|
||||
ctx->memset_params_host_buf, 0,
|
||||
ctx->memset_params_dev_buf, 0,
|
||||
ctx->memset_params_dev_buf.GetSize()
|
||||
);
|
||||
wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
|
||||
pass.SetPipeline(ctx->memset_pipeline);
|
||||
pass.SetBindGroup(0, bind_group);
|
||||
size_t bytes_per_wg = ctx->limits.maxComputeWorkgroupSizeX * ctx->memset_bytes_per_thread;
|
||||
pass.DispatchWorkgroups(((size + 3) + bytes_per_wg - 1) / bytes_per_wg, 1, 1);
|
||||
pass.End();
|
||||
wgpu::CommandBuffer commands = encoder.Finish();
|
||||
|
||||
ctx->queue.Submit(1, &commands);
|
||||
}
|
||||
|
||||
static void ggml_backend_webgpu_wait_on_submission(webgpu_context ctx) {
|
||||
// Wait for the queue to finish processing all commands
|
||||
ctx->instance.WaitAny(ctx->queue.OnSubmittedWorkDone(wgpu::CallbackMode::WaitAnyOnly,
|
||||
[](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
|
||||
if (status != wgpu::QueueWorkDoneStatus::Success) {
|
||||
GGML_LOG_ERROR("ggml_webgpu: Failed to wait on queue: %s\n", message.data);
|
||||
}
|
||||
}),
|
||||
UINT64_MAX
|
||||
);
|
||||
}
|
||||
|
||||
/** End WebGPU Actions */
|
||||
|
||||
/** GGML Backend Interface */
|
||||
|
||||
static const char * ggml_backend_webgpu_name(ggml_backend_t backend) {
|
||||
ggml_backend_webgpu_context * ctx = (ggml_backend_webgpu_context *)backend->context;
|
||||
return ctx->name.c_str();
|
||||
}
|
||||
|
||||
static void ggml_backend_webgpu_free(ggml_backend_t backend) {
|
||||
ggml_backend_webgpu_context * ctx = (ggml_backend_webgpu_context *)backend->context;
|
||||
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_free(" << ctx->name << ")");
|
||||
|
||||
// TODO: cleanup
|
||||
GGML_UNUSED(ctx);
|
||||
}
|
||||
|
||||
// Returns true if node has enqueued work into the queue, false otherwise
|
||||
static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node){
|
||||
if (ggml_is_empty(node)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
WEBGPU_LOG_DEBUG("ggml_webgpu_encode_node(" << node << ", " << ggml_op_name(node->op) << ")");
|
||||
|
||||
|
||||
switch (node->op) {
|
||||
// no-ops
|
||||
case GGML_OP_NONE:
|
||||
case GGML_OP_VIEW:
|
||||
case GGML_OP_PERMUTE:
|
||||
return false;
|
||||
|
||||
case GGML_OP_CPY: {
|
||||
std::lock_guard<std::mutex> lock(ctx->mutex);
|
||||
const ggml_tensor * src = node->src[0];
|
||||
ggml_backend_webgpu_buffer_context * src_ctx = (ggml_backend_webgpu_buffer_context *) src->buffer->context;
|
||||
size_t src_offset = webgpu_tensor_offset(src) + src->view_offs;
|
||||
// assumes power of 2 offset alignment
|
||||
size_t src_misalignment = src_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
|
||||
// align to minimum offset alignment
|
||||
src_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
|
||||
ggml_backend_webgpu_buffer_context * dst_ctx = (ggml_backend_webgpu_buffer_context *) node->buffer->context;
|
||||
size_t dst_offset = webgpu_tensor_offset(node) + node->view_offs;
|
||||
size_t dst_misalignment = dst_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
|
||||
dst_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
|
||||
|
||||
wgpu::Device device = ctx->device;
|
||||
ggml_backend_webgpu_map_buffer(ctx, ctx->cpy_params_host_buf,
|
||||
wgpu::MapMode::Write, 0, ctx->cpy_params_host_buf.GetSize());
|
||||
uint32_t * params = (uint32_t *) ctx->cpy_params_host_buf.GetMappedRange();
|
||||
uint32_t ne = (uint32_t)ggml_nelements(node);
|
||||
params[0] = ne;
|
||||
params[1] = src_misalignment/ggml_type_size(src->type);
|
||||
params[2] = dst_misalignment/ggml_type_size(node->type);
|
||||
|
||||
// Convert byte-strides to element-strides
|
||||
params[3] = (uint32_t)src->nb[0]/ggml_type_size(src->type);
|
||||
params[4] = (uint32_t)src->nb[1]/ggml_type_size(src->type);
|
||||
params[5] = (uint32_t)src->nb[2]/ggml_type_size(src->type);
|
||||
params[6] = (uint32_t)src->nb[3]/ggml_type_size(src->type);
|
||||
params[7] = (uint32_t)node->nb[0]/ggml_type_size(node->type);
|
||||
params[8] = (uint32_t)node->nb[1]/ggml_type_size(node->type);
|
||||
params[9] = (uint32_t)node->nb[2]/ggml_type_size(node->type);
|
||||
params[10] = (uint32_t)node->nb[3]/ggml_type_size(node->type);
|
||||
// Logical shape — same for both tensors even if permuted
|
||||
params[11] = (uint32_t)(src->ne[0]);
|
||||
params[12] = (uint32_t)(src->ne[1]);
|
||||
params[13] = (uint32_t)(src->ne[2]);
|
||||
params[14] = (uint32_t)(src->ne[3]);
|
||||
|
||||
ctx->cpy_params_host_buf.Unmap();
|
||||
|
||||
wgpu::BindGroupEntry entries[3];
|
||||
entries[0].binding = 0;
|
||||
entries[0].buffer = src_ctx->buffer;
|
||||
entries[0].offset = src_offset;
|
||||
entries[0].size = (ggml_nbytes(src) + src_misalignment + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) & ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1);
|
||||
|
||||
entries[1].binding = 1;
|
||||
entries[1].buffer = dst_ctx->buffer;
|
||||
entries[1].offset = dst_offset;
|
||||
entries[1].size = (ggml_nbytes(node) + dst_misalignment + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) & ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1);
|
||||
|
||||
entries[2].binding = 2;
|
||||
entries[2].buffer = ctx->cpy_params_dev_buf;
|
||||
entries[2].offset = 0;
|
||||
entries[2].size = ctx->cpy_params_dev_buf.GetSize();
|
||||
|
||||
wgpu::BindGroupDescriptor bind_group_desc;
|
||||
bind_group_desc.layout = ctx->cpy_pipeline.GetBindGroupLayout(0);
|
||||
bind_group_desc.label = "ggml_op_cpy";
|
||||
bind_group_desc.entryCount = 3;
|
||||
bind_group_desc.entries = entries;
|
||||
wgpu::BindGroup bind_group = device.CreateBindGroup(&bind_group_desc);
|
||||
|
||||
wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
|
||||
encoder.CopyBufferToBuffer(
|
||||
ctx->cpy_params_host_buf, 0,
|
||||
ctx->cpy_params_dev_buf, 0,
|
||||
ctx->cpy_params_dev_buf.GetSize()
|
||||
);
|
||||
wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
|
||||
pass.SetPipeline(ctx->cpy_pipeline);
|
||||
pass.SetBindGroup(0, bind_group);
|
||||
size_t max_wg_size = ctx->limits.maxComputeWorkgroupSizeX;
|
||||
pass.DispatchWorkgroups((ne + max_wg_size - 1) / max_wg_size);
|
||||
pass.End();
|
||||
wgpu::CommandBuffer commands = encoder.Finish();
|
||||
|
||||
// TODO, don't submit here, batch submissions
|
||||
ctx->queue.Submit(1, &commands);
|
||||
// TODO, don't wait on submission here
|
||||
ggml_backend_webgpu_wait_on_submission(ctx);
|
||||
return true;
|
||||
}
|
||||
|
||||
case GGML_OP_MUL_MAT:
|
||||
{
|
||||
const ggml_tensor * src0 = node->src[0];
|
||||
ggml_backend_webgpu_buffer_context * src0_ctx = (ggml_backend_webgpu_buffer_context *) src0->buffer->context;
|
||||
size_t src0_offset = webgpu_tensor_offset(src0) + src0->view_offs;
|
||||
const ggml_tensor * src1 = node->src[1];
|
||||
ggml_backend_webgpu_buffer_context * src1_ctx = (ggml_backend_webgpu_buffer_context *) src1->buffer->context;
|
||||
size_t src1_offset = webgpu_tensor_offset(src1) + src1->view_offs;
|
||||
ggml_backend_webgpu_buffer_context * dst_ctx = (ggml_backend_webgpu_buffer_context *) node->buffer->context;
|
||||
|
||||
size_t dst_offset = webgpu_tensor_offset(node) + node->view_offs;
|
||||
|
||||
wgpu::Device device = ctx->device;
|
||||
|
||||
// map the host parameters buffer
|
||||
ggml_backend_webgpu_map_buffer(ctx, ctx->mul_mat_params_host_buf,
|
||||
wgpu::MapMode::Write, 0, ctx->mul_mat_params_host_buf.GetSize());
|
||||
uint32_t * params = (uint32_t *) ctx->mul_mat_params_host_buf.GetMappedRange();
|
||||
|
||||
params[0] = (uint32_t)node->ne[1]; // number of rows in result (M)
|
||||
params[1] = (uint32_t)node->ne[0]; // number of columns in result (N)
|
||||
params[2] = (uint32_t)src0->ne[0]; // number of columns in src0/src1 (K)
|
||||
|
||||
params[3] = (uint32_t)src0->nb[1]/ggml_type_size(src0->type); // stride (elements) of src0 in dimension 1
|
||||
params[4] = (uint32_t)src1->nb[1]/ggml_type_size(src1->type); // stride (elements) of src1 in dimension 1
|
||||
params[5] = (uint32_t)src0->nb[2]/ggml_type_size(src0->type); // stride (elements) of src0 in dimension 2
|
||||
params[6] = (uint32_t)src1->nb[2]/ggml_type_size(src1->type); // stride (elements) of src1 in dimension 2
|
||||
params[7] = (uint32_t)src0->nb[3]/ggml_type_size(src0->type); // stride (elements) of src0 in dimension 3
|
||||
params[8] = (uint32_t)src1->nb[3]/ggml_type_size(src1->type); // stride (elements) of src1 in dimension 3
|
||||
|
||||
params[9] = (uint32_t)src0->ne[2]; // batch size in dimension 2
|
||||
params[10] = (uint32_t)src0->ne[3]; // batch size in dimension 3
|
||||
params[11] = (uint32_t)(src1->ne[2]/src0->ne[2]); // broadcast in dimension 2
|
||||
params[12] = (uint32_t)(src1->ne[3]/src0->ne[3]); // broadcast in dimension 3
|
||||
|
||||
ctx->mul_mat_params_host_buf.Unmap();
|
||||
|
||||
wgpu::BindGroupEntry entries[4];
|
||||
entries[0].binding = 0;
|
||||
entries[0].buffer = src0_ctx->buffer;
|
||||
entries[0].offset = src0_offset;
|
||||
entries[0].size = ggml_nbytes(src0);
|
||||
|
||||
entries[1].binding = 1;
|
||||
entries[1].buffer = src1_ctx->buffer;
|
||||
entries[1].offset = src1_offset;
|
||||
entries[1].size = ggml_nbytes(src1);
|
||||
|
||||
entries[2].binding = 2;
|
||||
entries[2].buffer = dst_ctx->buffer;
|
||||
entries[2].offset = dst_offset;
|
||||
entries[2].size = ggml_nbytes(node);
|
||||
|
||||
entries[3].binding = 3;
|
||||
entries[3].buffer = ctx->mul_mat_params_dev_buf;
|
||||
entries[3].offset = 0;
|
||||
entries[3].size = ctx->mul_mat_params_dev_buf.GetSize();
|
||||
|
||||
wgpu::BindGroupDescriptor bind_group_desc;
|
||||
bind_group_desc.layout = ctx->mul_mat_pipeline.GetBindGroupLayout(0);
|
||||
bind_group_desc.entryCount = 4;
|
||||
bind_group_desc.label = "ggml_op_mul_mat";
|
||||
bind_group_desc.entries = entries;
|
||||
wgpu::BindGroup bind_group = device.CreateBindGroup(&bind_group_desc);
|
||||
|
||||
wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
|
||||
encoder.CopyBufferToBuffer(
|
||||
ctx->mul_mat_params_host_buf, 0,
|
||||
ctx->mul_mat_params_dev_buf, 0,
|
||||
ctx->mul_mat_params_dev_buf.GetSize()
|
||||
);
|
||||
wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
|
||||
pass.SetPipeline(ctx->mul_mat_pipeline);
|
||||
pass.SetBindGroup(0, bind_group);
|
||||
pass.DispatchWorkgroups((node->ne[0] * node->ne[1] * node->ne[2] * node->ne[3] + WEBGPU_MUL_MAT_WG_SIZE - 1) / WEBGPU_MUL_MAT_WG_SIZE);
|
||||
pass.End();
|
||||
wgpu::CommandBuffer commands = encoder.Finish();
|
||||
|
||||
// TODO, don't submit here, batch submissions
|
||||
ctx->queue.Submit(1, &commands);
|
||||
// TODO, don't wait on submission here
|
||||
ggml_backend_webgpu_wait_on_submission(ctx);
|
||||
return true;
|
||||
}
|
||||
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
|
||||
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_graph_compute(" << cgraph->n_nodes << " nodes)");
|
||||
|
||||
ggml_backend_webgpu_context * backend_ctx = static_cast<ggml_backend_webgpu_context *>(backend->context);
|
||||
webgpu_context ctx = backend_ctx->webgpu_ctx;
|
||||
|
||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||
ggml_webgpu_encode_node(ctx, cgraph->nodes[i]);
|
||||
}
|
||||
|
||||
return GGML_STATUS_SUCCESS;
|
||||
}
|
||||
|
||||
static ggml_backend_i ggml_backend_webgpu_i = {
|
||||
/* .get_name = */ ggml_backend_webgpu_name,
|
||||
/* .free = */ ggml_backend_webgpu_free,
|
||||
/* .set_tensor_async = */ NULL,
|
||||
/* .get_tensor_async = */ NULL,
|
||||
/* .cpy_tensor_async = */ NULL,
|
||||
/* .synchronize = */ NULL,
|
||||
/* .graph_plan_create = */ NULL,
|
||||
/* .graph_plan_free = */ NULL,
|
||||
/* .graph_plan_update = */ NULL,
|
||||
/* .graph_plan_compute = */ NULL,
|
||||
/* .graph_compute = */ ggml_backend_webgpu_graph_compute,
|
||||
/* .event_record = */ NULL,
|
||||
/* .event_wait = */ NULL,
|
||||
};
|
||||
|
||||
/* End GGML Backend Interface */
|
||||
|
||||
/* GGML Backend Buffer Interface */
|
||||
|
||||
static void ggml_backend_webgpu_buffer_free_buffer(ggml_backend_buffer_t buffer) {
|
||||
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_free_buffer()");
|
||||
ggml_backend_webgpu_buffer_context * ctx = static_cast<ggml_backend_webgpu_buffer_context *>(buffer->context);
|
||||
ctx->buffer.Destroy();
|
||||
}
|
||||
|
||||
// Returns the "fake" base pointer.
|
||||
static void * ggml_backend_webgpu_buffer_get_base(ggml_backend_buffer_t buffer) {
|
||||
GGML_UNUSED(buffer);
|
||||
return webgpu_ptr_base;
|
||||
}
|
||||
|
||||
static void ggml_backend_webgpu_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
|
||||
if (size == 0) {
|
||||
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_memset_tensor: size is zero, nothing to do.");
|
||||
return;
|
||||
}
|
||||
|
||||
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_memset_tensor(" << buffer << ", " << tensor << ", " << value << ", " << offset << ", " << size << ")");
|
||||
|
||||
ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
|
||||
size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
|
||||
// This is a trick to set all bytes of a u32 to the same 1 byte value.
|
||||
uint32_t val32 = (uint32_t)value * 0x01010101;
|
||||
ggml_backend_webgpu_buffer_memset(buf_ctx->webgpu_ctx, buf_ctx->buffer, val32, total_offset, size);
|
||||
}
|
||||
|
||||
static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
|
||||
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_set_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")");
|
||||
ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
|
||||
webgpu_context webgpu_ctx = buf_ctx->webgpu_ctx;
|
||||
|
||||
size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
|
||||
|
||||
webgpu_ctx->queue.WriteBuffer(buf_ctx->buffer, total_offset, data, (size/4)*4);
|
||||
|
||||
if (size % 4 != 0) {
|
||||
// If size is not a multiple of 4, we need to memset the remaining bytes
|
||||
size_t remaining_size = size % 4;
|
||||
// pack the remaining bytes into a uint32_t
|
||||
uint32_t val32 = 0;
|
||||
for (size_t i = 0; i < remaining_size; i++) {
|
||||
((uint8_t *)&val32)[i] = ((const uint8_t *)data)[size - remaining_size + i];
|
||||
}
|
||||
// memset the remaining bytes
|
||||
ggml_backend_webgpu_buffer_memset(webgpu_ctx, buf_ctx->buffer, val32, total_offset + (size - remaining_size), remaining_size);
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
|
||||
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_get_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")");
|
||||
|
||||
ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
|
||||
webgpu_context webgpu_ctx = buf_ctx->webgpu_ctx;
|
||||
wgpu::Device device = webgpu_ctx->device;
|
||||
|
||||
size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
|
||||
|
||||
size_t final_size = size;
|
||||
if (size % 4 != 0) {
|
||||
// If size is not a multiple of 4, we need to round it up to the next multiple of 4
|
||||
final_size = size + (4 - (size % 4));
|
||||
}
|
||||
|
||||
std::lock_guard<std::mutex> lock(webgpu_ctx->mutex);
|
||||
|
||||
if (webgpu_ctx->get_tensor_staging_buf == nullptr ||
|
||||
webgpu_ctx->get_tensor_staging_buf.GetSize() < final_size) {
|
||||
// Create a new staging buffer if it doesn't exist or is too small
|
||||
if (webgpu_ctx->get_tensor_staging_buf) {
|
||||
webgpu_ctx->get_tensor_staging_buf.Destroy();
|
||||
}
|
||||
ggml_webgpu_create_buffer(device, webgpu_ctx->get_tensor_staging_buf, final_size,
|
||||
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "get_tensor_staging_buf");
|
||||
}
|
||||
|
||||
// Copy the data from the buffer to the staging buffer
|
||||
wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
|
||||
encoder.CopyBufferToBuffer(buf_ctx->buffer, total_offset, webgpu_ctx->get_tensor_staging_buf, 0, final_size);
|
||||
wgpu::CommandBuffer commands = encoder.Finish();
|
||||
// Submit the command buffer to the queue
|
||||
webgpu_ctx->queue.Submit(1, &commands);
|
||||
|
||||
// Map the staging buffer to read the data
|
||||
ggml_backend_webgpu_map_buffer(webgpu_ctx, webgpu_ctx->get_tensor_staging_buf, wgpu::MapMode::Read, 0, final_size);
|
||||
// Must specify size here since the staging buffer might be larger than the tensor size
|
||||
const void * mapped_range = webgpu_ctx->get_tensor_staging_buf.GetConstMappedRange(0, final_size);
|
||||
|
||||
// Copy the data from the mapped range to the output buffer
|
||||
std::memcpy(data, mapped_range, size);
|
||||
webgpu_ctx->get_tensor_staging_buf.Unmap();
|
||||
}
|
||||
|
||||
static void ggml_backend_webgpu_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
|
||||
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_clear(" << buffer << ", " << (uint32_t) value << ")");
|
||||
|
||||
ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
|
||||
ggml_backend_webgpu_buffer_memset(buf_ctx->webgpu_ctx, buf_ctx->buffer, value, 0, buffer->size);
|
||||
}
|
||||
|
||||
static ggml_backend_buffer_i ggml_backend_webgpu_buffer_interface = {
|
||||
/* .free_buffer = */ ggml_backend_webgpu_buffer_free_buffer,
|
||||
/* .get_base = */ ggml_backend_webgpu_buffer_get_base,
|
||||
/* .init_tensor = */ NULL, // TODO: optional, needed?
|
||||
/* .memset_tensor = */ ggml_backend_webgpu_buffer_memset_tensor,
|
||||
/* .set_tensor = */ ggml_backend_webgpu_buffer_set_tensor,
|
||||
/* .get_tensor = */ ggml_backend_webgpu_buffer_get_tensor,
|
||||
/* .cpy_tensor = */ NULL, // TODO: optional, implement this
|
||||
/* .clear = */ ggml_backend_webgpu_buffer_clear,
|
||||
/* .reset = */ NULL, // TODO: optional, think it coordinates with .init_tensor
|
||||
};
|
||||
|
||||
/* End GGML Backend Buffer Interface */
|
||||
|
||||
/* GGML Backend Buffer Type Interface */
|
||||
|
||||
static const char * ggml_backend_webgpu_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
|
||||
ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
|
||||
return ctx->device_name.c_str();
|
||||
}
|
||||
|
||||
static ggml_backend_buffer_t ggml_backend_webgpu_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
|
||||
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_type_alloc_buffer(" << size << ")");
|
||||
ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
|
||||
|
||||
wgpu::Buffer buf;
|
||||
ggml_webgpu_create_buffer(ctx->webgpu_ctx->device, buf, size,
|
||||
wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst, "allocated_buffer");
|
||||
|
||||
ggml_backend_webgpu_buffer_context * buf_ctx = new ggml_backend_webgpu_buffer_context(ctx->webgpu_ctx, buf);
|
||||
|
||||
return ggml_backend_buffer_init(buft, ggml_backend_webgpu_buffer_interface, buf_ctx, size);
|
||||
}
|
||||
|
||||
static size_t ggml_backend_webgpu_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
|
||||
ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
|
||||
return ctx->webgpu_ctx->limits.minStorageBufferOffsetAlignment;
|
||||
}
|
||||
|
||||
// maxBufferSize might be larger, but you can't bind more than maxStorageBufferBindingSize to a single binding.
|
||||
static size_t ggml_backend_webgpu_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
|
||||
ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
|
||||
return ctx->webgpu_ctx->limits.maxStorageBufferBindingSize;
|
||||
}
|
||||
|
||||
/* End GGML Backend Buffer Type Interface */
|
||||
|
||||
/* GGML Backend Device Interface */
|
||||
|
||||
static const char * ggml_backend_webgpu_device_get_name(ggml_backend_dev_t dev) {
|
||||
ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
|
||||
return ctx->device_name.c_str();
|
||||
}
|
||||
|
||||
static const char * ggml_backend_webgpu_device_get_description(ggml_backend_dev_t dev) {
|
||||
ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
|
||||
return ctx->device_desc.c_str();
|
||||
}
|
||||
|
||||
static void ggml_backend_webgpu_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
|
||||
ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
|
||||
// TODO: what do we actually want to return here? maxBufferSize might not be the full available memory.
|
||||
*free = ctx->webgpu_ctx->limits.maxBufferSize;
|
||||
*total = ctx->webgpu_ctx->limits.maxBufferSize;
|
||||
}
|
||||
|
||||
static enum ggml_backend_dev_type ggml_backend_webgpu_device_get_type(ggml_backend_dev_t dev) {
|
||||
GGML_UNUSED(dev);
|
||||
return GGML_BACKEND_DEVICE_TYPE_GPU;
|
||||
}
|
||||
|
||||
static void ggml_backend_webgpu_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
|
||||
props->name = ggml_backend_webgpu_device_get_name(dev);
|
||||
props->description = ggml_backend_webgpu_device_get_description(dev);
|
||||
props->type = ggml_backend_webgpu_device_get_type(dev);
|
||||
ggml_backend_webgpu_device_get_memory(dev, &props->memory_free, &props->memory_total);
|
||||
props->caps = {
|
||||
/* .async = */ false,
|
||||
/* .host_buffer = */ false,
|
||||
/* .buffer_from_host_ptr = */ false,
|
||||
/* .events = */ false,
|
||||
};
|
||||
}
|
||||
|
||||
static ggml_guid_t ggml_backend_webgpu_guid(void) {
|
||||
static const char * guid_str = "__ggml_webgpu :)";
|
||||
return reinterpret_cast<ggml_guid_t>((void *)guid_str);
|
||||
}
|
||||
|
||||
static void ggml_webgpu_init_memset_pipeline(webgpu_context webgpu_ctx) {
|
||||
// we use the maximum workgroup size for the memset pipeline
|
||||
size_t max_wg_size = webgpu_ctx->limits.maxComputeWorkgroupSizeX;
|
||||
size_t max_threads = max_wg_size * webgpu_ctx->limits.maxComputeWorkgroupsPerDimension;
|
||||
// Size the bytes_per_thread so that the largest buffer size can be handled
|
||||
webgpu_ctx->memset_bytes_per_thread = (webgpu_ctx->limits.maxStorageBufferBindingSize + max_threads - 1) / max_threads;
|
||||
std::vector<wgpu::ConstantEntry> constants(2);
|
||||
constants[0].key = "wg_size";
|
||||
constants[0].value = max_wg_size;
|
||||
constants[1].key = "bytes_per_thread";
|
||||
constants[1].value = webgpu_ctx->memset_bytes_per_thread;
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->memset_pipeline, wgsl_memset, "memset", constants);
|
||||
ggml_webgpu_create_buffer(webgpu_ctx->device, webgpu_ctx->memset_params_dev_buf,
|
||||
3 * sizeof(uint32_t), // 3 parameters: buffer size, offset, value
|
||||
wgpu::BufferUsage::Uniform | wgpu::BufferUsage::CopyDst, "memset_params_dev_buf");
|
||||
ggml_webgpu_create_buffer(webgpu_ctx->device, webgpu_ctx->memset_params_host_buf,
|
||||
3 * sizeof(uint32_t), wgpu::BufferUsage::MapWrite | wgpu::BufferUsage::CopySrc, "memset_params_host_buf");
|
||||
}
|
||||
|
||||
static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context webgpu_ctx) {
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline, wgsl_mul_mat, "mul_mat");
|
||||
ggml_webgpu_create_buffer(webgpu_ctx->device, webgpu_ctx->mul_mat_params_dev_buf, WEBGPU_MUL_MAT_PARAMS_SIZE,
|
||||
wgpu::BufferUsage::Uniform | wgpu::BufferUsage::CopyDst, "mul_mat_params_dev_buf");
|
||||
ggml_webgpu_create_buffer(webgpu_ctx->device, webgpu_ctx->mul_mat_params_host_buf, WEBGPU_MUL_MAT_PARAMS_SIZE,
|
||||
wgpu::BufferUsage::MapWrite | wgpu::BufferUsage::CopySrc, "mul_mat_params_host_buf");
|
||||
}
|
||||
|
||||
static void ggml_webgpu_init_cpy_pipeline(webgpu_context webgpu_ctx) {
|
||||
std::vector<wgpu::ConstantEntry> constants(1);
|
||||
constants[0].key = "wg_size";
|
||||
constants[0].value = webgpu_ctx->limits.maxComputeWorkgroupSizeX;
|
||||
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline, wgsl_cpy, "cpy", constants);
|
||||
ggml_webgpu_create_buffer(webgpu_ctx->device, webgpu_ctx->cpy_params_dev_buf, WEBGPU_CPY_PARAMS_SIZE,
|
||||
wgpu::BufferUsage::Uniform | wgpu::BufferUsage::CopyDst, "cpy_params_dev_buf");
|
||||
ggml_webgpu_create_buffer(webgpu_ctx->device, webgpu_ctx->cpy_params_host_buf, WEBGPU_CPY_PARAMS_SIZE,
|
||||
wgpu::BufferUsage::MapWrite | wgpu::BufferUsage::CopySrc, "cpy_params_host_buf");
|
||||
}
|
||||
|
||||
// TODO: Make thread safe if multiple devices are used
|
||||
static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, const char * params) {
|
||||
GGML_UNUSED(params);
|
||||
|
||||
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_device_init()");
|
||||
|
||||
ggml_backend_webgpu_device_context * dev_ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
|
||||
webgpu_context webgpu_ctx = dev_ctx->webgpu_ctx;
|
||||
|
||||
std::lock_guard<std::mutex> lock(webgpu_ctx->mutex);
|
||||
|
||||
if (!webgpu_ctx->device_initialized) {
|
||||
// Initialize device
|
||||
wgpu::DeviceDescriptor dev_desc;
|
||||
dev_desc.requiredLimits = &webgpu_ctx->limits;
|
||||
dev_desc.requiredFeatures = webgpu_ctx->features.features;
|
||||
dev_desc.requiredFeatureCount = webgpu_ctx->features.featureCount;
|
||||
dev_desc.SetDeviceLostCallback(wgpu::CallbackMode::AllowSpontaneous,
|
||||
[](const wgpu::Device& device, wgpu::DeviceLostReason reason, wgpu::StringView message) {
|
||||
GGML_UNUSED(device);
|
||||
GGML_LOG_ERROR("ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast<int>(reason), message.data);
|
||||
});
|
||||
dev_desc.SetUncapturedErrorCallback(
|
||||
[](const wgpu::Device& device, wgpu::ErrorType reason, wgpu::StringView message) {
|
||||
GGML_UNUSED(device);
|
||||
GGML_LOG_ERROR("ggml_webgpu: Device error! Reason: %d, Message: %s\n", static_cast<int>(reason), message.data);
|
||||
});
|
||||
webgpu_ctx->instance.WaitAny(webgpu_ctx->adapter.RequestDevice(&dev_desc, wgpu::CallbackMode::WaitAnyOnly,
|
||||
[webgpu_ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) {
|
||||
if (status != wgpu::RequestDeviceStatus::Success) {
|
||||
GGML_LOG_ERROR("ggml_webgpu: Failed to get a device: %s\n", message.data);
|
||||
return;
|
||||
}
|
||||
webgpu_ctx->device = device;
|
||||
}),
|
||||
UINT64_MAX
|
||||
);
|
||||
GGML_ASSERT(webgpu_ctx->device != nullptr);
|
||||
|
||||
// Initialize (compute) queue
|
||||
webgpu_ctx->queue = webgpu_ctx->device.GetQueue();
|
||||
|
||||
ggml_webgpu_init_memset_pipeline(webgpu_ctx);
|
||||
ggml_webgpu_init_mul_mat_pipeline(webgpu_ctx);
|
||||
ggml_webgpu_init_cpy_pipeline(webgpu_ctx);
|
||||
webgpu_ctx->device_initialized = true;
|
||||
}
|
||||
|
||||
static ggml_backend_webgpu_context backend_ctx;
|
||||
backend_ctx.name = GGML_WEBGPU_NAME + std::string(": ") + dev_ctx->device_name;
|
||||
backend_ctx.webgpu_ctx = webgpu_ctx;
|
||||
|
||||
// See GGML Backend Interface section
|
||||
static ggml_backend backend = {
|
||||
/* .guid = */ ggml_backend_webgpu_guid(),
|
||||
/* .interface = */ ggml_backend_webgpu_i,
|
||||
/* .device = */ dev,
|
||||
/* .context = */ &backend_ctx,
|
||||
};
|
||||
|
||||
return &backend;
|
||||
}
|
||||
|
||||
static ggml_backend_buffer_type_t ggml_backend_webgpu_device_get_buffer_type(ggml_backend_dev_t dev) {
|
||||
// See GGML Backend Buffer Type Interface section
|
||||
static struct ggml_backend_buffer_type ggml_backend_webgpu_buffer_type = {
|
||||
/* .iface = */ {
|
||||
/* .get_name = */ ggml_backend_webgpu_buffer_type_get_name,
|
||||
/* .alloc_buffer = */ ggml_backend_webgpu_buffer_type_alloc_buffer,
|
||||
/* .get_alignment = */ ggml_backend_webgpu_buffer_type_get_alignment,
|
||||
/* .get_max_size = */ ggml_backend_webgpu_buffer_type_get_max_size,
|
||||
/* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
|
||||
/* .is_host = */ NULL, // defaults to false
|
||||
},
|
||||
/* .device = */ dev,
|
||||
/* .context = */ NULL,
|
||||
};
|
||||
|
||||
return &ggml_backend_webgpu_buffer_type;
|
||||
}
|
||||
|
||||
static bool ggml_backend_webgpu_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
|
||||
GGML_UNUSED(dev);
|
||||
return buft->iface.get_name == ggml_backend_webgpu_buffer_type_get_name;
|
||||
}
|
||||
|
||||
static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
|
||||
GGML_UNUSED(dev);
|
||||
|
||||
switch (op->op) {
|
||||
case GGML_OP_NONE:
|
||||
case GGML_OP_VIEW:
|
||||
case GGML_OP_PERMUTE:
|
||||
return true;
|
||||
case GGML_OP_CPY:
|
||||
return op->type == GGML_TYPE_F16 && op->src[0]->type == GGML_TYPE_F32;
|
||||
case GGML_OP_MUL_MAT:
|
||||
return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
static struct ggml_backend_device_i ggml_backend_webgpu_device_i = {
|
||||
/* .get_name = */ ggml_backend_webgpu_device_get_name,
|
||||
/* .get_description = */ ggml_backend_webgpu_device_get_description,
|
||||
/* .get_memory = */ ggml_backend_webgpu_device_get_memory,
|
||||
/* .get_type = */ ggml_backend_webgpu_device_get_type,
|
||||
/* .get_props = */ ggml_backend_webgpu_device_get_props,
|
||||
/* .init_backend = */ ggml_backend_webgpu_device_init,
|
||||
/* .get_buffer_type = */ ggml_backend_webgpu_device_get_buffer_type,
|
||||
/* .get_host_buffer_type = */ NULL,
|
||||
/* .buffer_from_host_ptr = */ NULL,
|
||||
/* .supports_op = */ ggml_backend_webgpu_device_supports_op,
|
||||
/* .supports_buft = */ ggml_backend_webgpu_device_supports_buft,
|
||||
/* .offload_op = */ NULL,
|
||||
/* .event_new = */ NULL,
|
||||
/* .event_free = */ NULL,
|
||||
/* .event_synchronize = */ NULL,
|
||||
};
|
||||
|
||||
/* End GGML Backend Device Interface */
|
||||
|
||||
/* GGML Backend Registration Interface */
|
||||
|
||||
static const char * ggml_backend_webgpu_reg_get_name(ggml_backend_reg_t reg) {
|
||||
ggml_backend_webgpu_reg_context * ctx = static_cast<ggml_backend_webgpu_reg_context *>(reg->context);
|
||||
return ctx->name;
|
||||
}
|
||||
|
||||
static size_t ggml_backend_webgpu_reg_get_device_count(ggml_backend_reg_t reg) {
|
||||
ggml_backend_webgpu_reg_context * ctx = static_cast<ggml_backend_webgpu_reg_context *>(reg->context);
|
||||
return ctx->device_count;
|
||||
}
|
||||
|
||||
// TODO: Does this need to be thread safe? Is it only called once?
|
||||
// Only one device is supported for now
|
||||
static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t reg, size_t index) {
|
||||
GGML_ASSERT(index == 0);
|
||||
WEBGPU_LOG_DEBUG("ggml_backend_reg_get_device()");
|
||||
|
||||
ggml_backend_webgpu_reg_context * reg_ctx = static_cast<ggml_backend_webgpu_reg_context *>(reg->context);
|
||||
|
||||
webgpu_context ctx = reg_ctx->webgpu_ctx;
|
||||
|
||||
wgpu::RequestAdapterOptions options = {};
|
||||
auto callback = [](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char *message, void *userdata) {
|
||||
if (status != wgpu::RequestAdapterStatus::Success) {
|
||||
GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message);
|
||||
return;
|
||||
}
|
||||
*static_cast<wgpu::Adapter *>(userdata) = adapter;
|
||||
};
|
||||
void *userdata = &ctx->adapter;
|
||||
ctx->instance.WaitAny(ctx->instance.RequestAdapter(&options, wgpu::CallbackMode::WaitAnyOnly, callback, userdata), UINT64_MAX);
|
||||
GGML_ASSERT(ctx->adapter != nullptr);
|
||||
|
||||
ctx->adapter.GetLimits(&ctx->limits);
|
||||
ctx->adapter.GetFeatures(&ctx->features);
|
||||
|
||||
wgpu::AdapterInfo info{};
|
||||
ctx->adapter.GetInfo(&info);
|
||||
|
||||
static ggml_backend_webgpu_device_context device_ctx;
|
||||
device_ctx.webgpu_ctx = ctx;
|
||||
device_ctx.device_name = GGML_WEBGPU_NAME;
|
||||
device_ctx.device_desc = std::string(info.description.data);
|
||||
|
||||
GGML_LOG_INFO("ggml_webgpu: adapter_info: vendor_id: %u | vendor: %s | architecture: %s | device_id: %u | name: %s | device_desc: %s\n",
|
||||
info.vendorID, info.vendor.data, info.architecture.data, info.deviceID, info.device.data, info.description.data);
|
||||
|
||||
// See GGML Backend Device Interface section
|
||||
static ggml_backend_device device = {
|
||||
/* .iface = */ ggml_backend_webgpu_device_i,
|
||||
/* .reg = */ reg,
|
||||
/* .context = */ &device_ctx,
|
||||
};
|
||||
return &device;
|
||||
}
|
||||
|
||||
|
||||
static const struct ggml_backend_reg_i ggml_backend_webgpu_reg_i = {
|
||||
/* .get_name = */ ggml_backend_webgpu_reg_get_name,
|
||||
/* .get_device_count = */ ggml_backend_webgpu_reg_get_device_count,
|
||||
/* .get_device = */ ggml_backend_webgpu_reg_get_device,
|
||||
/* .get_proc_address = */ NULL,
|
||||
};
|
||||
|
||||
/* End GGML Backend Registration Interface */
|
||||
|
||||
// TODO: Does this need to be thread safe? Is it only called once?
|
||||
ggml_backend_reg_t ggml_backend_webgpu_reg() {
|
||||
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_reg()");
|
||||
|
||||
webgpu_context webgpu_ctx = std::make_shared<webgpu_context_struct>();
|
||||
webgpu_ctx->device_initialized = false;
|
||||
|
||||
static ggml_backend_webgpu_reg_context ctx;
|
||||
ctx.webgpu_ctx = webgpu_ctx;
|
||||
ctx.name = GGML_WEBGPU_NAME;
|
||||
ctx.device_count = 1;
|
||||
|
||||
wgpu::InstanceDescriptor instance_descriptor{};
|
||||
std::vector<wgpu::InstanceFeatureName> instance_features = {wgpu::InstanceFeatureName::TimedWaitAny};
|
||||
instance_descriptor.requiredFeatures = instance_features.data();
|
||||
instance_descriptor.requiredFeatureCount = instance_features.size();
|
||||
webgpu_ctx->instance = wgpu::CreateInstance(&instance_descriptor);
|
||||
GGML_ASSERT(webgpu_ctx->instance != nullptr);
|
||||
|
||||
static ggml_backend_reg reg = {
|
||||
/* .api_version = */ GGML_BACKEND_API_VERSION,
|
||||
/* .iface = */ ggml_backend_webgpu_reg_i,
|
||||
/* .context = */ &ctx,
|
||||
};
|
||||
return ®
|
||||
}
|
||||
|
||||
ggml_backend_t ggml_backend_webgpu_init(void) {
|
||||
ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_webgpu_reg(), 0);
|
||||
|
||||
return ggml_backend_webgpu_device_init(dev, nullptr);
|
||||
}
|
||||
|
||||
GGML_BACKEND_DL_IMPL(ggml_backend_webgpu_reg)
|
60
ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl
Normal file
60
ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl
Normal file
|
@ -0,0 +1,60 @@
|
|||
enable f16;
|
||||
|
||||
@group(0) @binding(0)
|
||||
var<storage, read_write> src: array<f32>;
|
||||
|
||||
@group(0) @binding(1)
|
||||
var<storage, read_write> dst: array<f16>;
|
||||
|
||||
struct Params {
|
||||
ne: u32, // total number of elements
|
||||
offset_src: u32, // in elements
|
||||
offset_dst: u32, // in elements
|
||||
|
||||
// Strides (in elements) — may be permuted
|
||||
stride_src0: u32,
|
||||
stride_src1: u32,
|
||||
stride_src2: u32,
|
||||
stride_src3: u32,
|
||||
|
||||
stride_dst0: u32,
|
||||
stride_dst1: u32,
|
||||
stride_dst2: u32,
|
||||
stride_dst3: u32,
|
||||
|
||||
// Logical shape (same for both tensors)
|
||||
ne0: u32,
|
||||
ne1: u32,
|
||||
ne2: u32,
|
||||
ne3: u32,
|
||||
};
|
||||
|
||||
@group(0) @binding(2)
|
||||
var<uniform> params: Params;
|
||||
|
||||
override wg_size: u32;
|
||||
@compute @workgroup_size(wg_size)
|
||||
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
||||
if (gid.x >= params.ne) {
|
||||
return;
|
||||
}
|
||||
|
||||
var i = gid.x;
|
||||
|
||||
let i3 = i / (params.ne2 * params.ne1 * params.ne0);
|
||||
i = i % (params.ne2 * params.ne1 * params.ne0);
|
||||
|
||||
let i2 = i / (params.ne1 * params.ne0);
|
||||
i = i % (params.ne1 * params.ne0);
|
||||
|
||||
let i1 = i / params.ne0;
|
||||
let i0 = i % params.ne0;
|
||||
|
||||
let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 +
|
||||
i2 * params.stride_src2 + i3 * params.stride_src3;
|
||||
|
||||
let dst_idx = i0 * params.stride_dst0 + i1 * params.stride_dst1 +
|
||||
i2 * params.stride_dst2 + i3 * params.stride_dst3;
|
||||
|
||||
dst[params.offset_dst + dst_idx] = f16(src[params.offset_src + src_idx]);
|
||||
}
|
35
ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py
Executable file
35
ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py
Executable file
|
@ -0,0 +1,35 @@
|
|||
import os
|
||||
import argparse
|
||||
|
||||
|
||||
def escape_triple_quotes(wgsl):
|
||||
# Simple defense in case of embedded """
|
||||
return wgsl.replace('"""', '\\"""')
|
||||
|
||||
|
||||
def to_cpp_string_literal(varname, content):
|
||||
return f'const char* wgsl_{varname} = R"({content})";\n'
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--input', required=True)
|
||||
parser.add_argument('--output', required=True)
|
||||
args = parser.parse_args()
|
||||
|
||||
with open(args.output, 'w', encoding='utf-8') as out:
|
||||
out.write("// Auto-generated shader embedding \n\n")
|
||||
for fname in sorted(os.listdir(args.input)):
|
||||
if not fname.endswith('.wgsl'):
|
||||
continue
|
||||
shader_path = os.path.join(args.input, fname)
|
||||
varname = os.path.splitext(fname)[0]
|
||||
with open(shader_path, 'r', encoding='utf-8') as f:
|
||||
content = f.read()
|
||||
content = escape_triple_quotes(content)
|
||||
out.write(to_cpp_string_literal(varname, content))
|
||||
out.write('\n')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
40
ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl
Normal file
40
ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl
Normal file
|
@ -0,0 +1,40 @@
|
|||
@group(0) @binding(0)
|
||||
var<storage, read_write> output_buffer: array<u32>;
|
||||
|
||||
struct Params {
|
||||
offset: u32, // in bytes
|
||||
size: u32, // in bytes
|
||||
value: u32, // 4 8-bit values, which are either repeating (memset_tensor) or may be separate (cleaning up unaligned set_tensor operations)
|
||||
};
|
||||
|
||||
@group(0) @binding(1)
|
||||
var<uniform> params: Params;
|
||||
|
||||
override wg_size: u32;
|
||||
override bytes_per_thread: u32;
|
||||
|
||||
@compute @workgroup_size(wg_size)
|
||||
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
||||
let i = gid.x * bytes_per_thread;
|
||||
let start = params.offset;
|
||||
let end = params.offset + params.size;
|
||||
|
||||
for (var j: u32 = 0u; j < bytes_per_thread; j = j + 1u) {
|
||||
let byte_index = start + i + j;
|
||||
if (byte_index + 4u <= end) {
|
||||
output_buffer[(byte_index >> 2u)] = params.value;
|
||||
} else {
|
||||
// Handle tail (unaligned)
|
||||
for (var k: u32 = 0u; k < 4u; k = k + 1u) {
|
||||
let idx = byte_index + k;
|
||||
if (idx < end) {
|
||||
let word_idx = idx >> 2u;
|
||||
let byte_offset = (idx & 3u) * 8u;
|
||||
let mask = ~(0xffu << byte_offset);
|
||||
let existing = output_buffer[word_idx];
|
||||
output_buffer[word_idx] = (existing & mask) | ((params.value & 0xffu) << byte_offset);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
56
ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl
Normal file
56
ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl
Normal file
|
@ -0,0 +1,56 @@
|
|||
struct MulMatParams {
|
||||
m: u32,
|
||||
n: u32,
|
||||
k: u32,
|
||||
// all strides are in elements
|
||||
stride_01: u32,
|
||||
stride_11: u32,
|
||||
stride_02: u32,
|
||||
stride_12: u32,
|
||||
stride_03: u32,
|
||||
stride_13: u32,
|
||||
|
||||
bs02: u32,
|
||||
bs03: u32,
|
||||
broadcast2: u32,
|
||||
broadcast3: u32
|
||||
};
|
||||
|
||||
@group(0) @binding(0) var<storage, read_write> src0: array<f32>; // N rows, K columns
|
||||
@group(0) @binding(1) var<storage, read_write> src1: array<f32>; // M rows, K columns (transposed)
|
||||
@group(0) @binding(2) var<storage, read_write> dst: array<f32>; // M rows, N columns
|
||||
|
||||
@group(0) @binding(3) var<uniform> params: MulMatParams;
|
||||
|
||||
@compute @workgroup_size(64)
|
||||
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
|
||||
let total = params.m * params.n * params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3;
|
||||
if (global_id.x >= total) {
|
||||
return;
|
||||
}
|
||||
|
||||
let dst2_stride = params.m * params.n;
|
||||
let dst3_stride = dst2_stride * params.bs02 * params.broadcast2;
|
||||
|
||||
let dst3_idx = global_id.x / dst3_stride;
|
||||
let src03_idx = dst3_idx / params.broadcast3; // src0 may be broadcast along the third dimension
|
||||
let src13_idx = dst3_idx; // src1 is not broadcast
|
||||
let dst3_rem = global_id.x % dst3_stride;
|
||||
|
||||
let dst2_idx = dst3_rem / dst2_stride;
|
||||
let src02_idx = dst2_idx / params.broadcast2; // src0 may also be broadcast along the second dimension
|
||||
let src12_idx = dst2_idx; // src1 is not broadcast
|
||||
|
||||
let dst2_rem = dst3_rem % dst2_stride;
|
||||
|
||||
let row = dst2_rem / params.n; // output row
|
||||
let col = dst2_rem % params.n; // output column
|
||||
|
||||
var sum = 0.0;
|
||||
for (var i: u32 = 0u; i < params.k; i = i + 1u) {
|
||||
let src0_idx = src03_idx * params.stride_03 + src02_idx * params.stride_02 + col * params.stride_01 + i;
|
||||
let src1_idx = src13_idx * params.stride_13 + src12_idx * params.stride_12 + row * params.stride_11 + i;
|
||||
sum = sum + src0[src0_idx] * src1[src1_idx];
|
||||
}
|
||||
dst[dst3_idx * dst3_stride + dst2_idx * dst2_stride + row * params.n + col] = sum;
|
||||
}
|
|
@ -367,6 +367,7 @@ class MODEL_ARCH(IntEnum):
|
|||
HUNYUAN_MOE = auto()
|
||||
SMOLLM3 = auto()
|
||||
LFM2 = auto()
|
||||
DREAM = auto()
|
||||
|
||||
|
||||
class VISION_PROJECTOR_TYPE(IntEnum):
|
||||
|
@ -683,6 +684,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
|||
MODEL_ARCH.HUNYUAN_MOE: "hunyuan-moe",
|
||||
MODEL_ARCH.SMOLLM3: "smollm3",
|
||||
MODEL_ARCH.LFM2: "lfm2",
|
||||
MODEL_ARCH.DREAM: "dream",
|
||||
}
|
||||
|
||||
VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
|
||||
|
@ -1289,6 +1291,21 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
|||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
],
|
||||
MODEL_ARCH.DREAM: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.ROPE_FREQS,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_K,
|
||||
MODEL_TENSOR.ATTN_V,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.FFN_NORM,
|
||||
MODEL_TENSOR.FFN_GATE,
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
],
|
||||
MODEL_ARCH.QWEN2VL: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
|
|
|
@ -338,6 +338,9 @@ extern "C" {
|
|||
bool swa_full; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
|
||||
// NOTE: setting to false when n_seq_max > 1 can cause bad performance in some cases
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/13845#issuecomment-2924800573
|
||||
bool kv_unified; // use a unified buffer across the input sequences when computing the attention
|
||||
// try to disable when n_seq_max > 1 for improved performance when the sequences do not share a large prefix
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/14363
|
||||
};
|
||||
|
||||
// model quantization parameters
|
||||
|
@ -728,7 +731,7 @@ extern "C" {
|
|||
// - lazily on next llama_decode()
|
||||
// p0 < 0 : [0, p1]
|
||||
// p1 < 0 : [p0, inf)
|
||||
DEPRECATED(void llama_kv_self_seq_div(
|
||||
DEPRECATED(LLAMA_API void llama_kv_self_seq_div(
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
|
@ -1008,6 +1011,7 @@ extern "C" {
|
|||
LLAMA_API llama_token llama_vocab_sep(const struct llama_vocab * vocab); // sentence separator
|
||||
LLAMA_API llama_token llama_vocab_nl (const struct llama_vocab * vocab); // next-line
|
||||
LLAMA_API llama_token llama_vocab_pad(const struct llama_vocab * vocab); // padding
|
||||
LLAMA_API llama_token llama_vocab_mask(const struct llama_vocab * vocab); // mask
|
||||
|
||||
LLAMA_API bool llama_vocab_get_add_bos(const struct llama_vocab * vocab);
|
||||
LLAMA_API bool llama_vocab_get_add_eos(const struct llama_vocab * vocab);
|
||||
|
|
|
@ -85,6 +85,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
|||
{ LLM_ARCH_HUNYUAN_MOE, "hunyuan-moe" },
|
||||
{ LLM_ARCH_SMOLLM3, "smollm3" },
|
||||
{ LLM_ARCH_LFM2, "lfm2" },
|
||||
{ LLM_ARCH_DREAM, "dream" },
|
||||
{ LLM_ARCH_UNKNOWN, "(unknown)" },
|
||||
};
|
||||
|
||||
|
@ -1891,6 +1892,23 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
|||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_DREAM,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
|
||||
|
@ -2133,3 +2151,12 @@ bool llm_arch_is_hybrid(const llm_arch & arch) {
|
|||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
bool llm_arch_is_diffusion(const llm_arch & arch) {
|
||||
switch (arch) {
|
||||
case LLM_ARCH_DREAM:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -89,6 +89,7 @@ enum llm_arch {
|
|||
LLM_ARCH_HUNYUAN_MOE,
|
||||
LLM_ARCH_SMOLLM3,
|
||||
LLM_ARCH_LFM2,
|
||||
LLM_ARCH_DREAM,
|
||||
LLM_ARCH_UNKNOWN,
|
||||
};
|
||||
|
||||
|
@ -479,3 +480,4 @@ const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor);
|
|||
|
||||
bool llm_arch_is_recurrent(const llm_arch & arch);
|
||||
bool llm_arch_is_hybrid (const llm_arch & arch);
|
||||
bool llm_arch_is_diffusion(const llm_arch & arch);
|
||||
|
|
|
@ -27,6 +27,7 @@ bool llama_batch_allocr::init(
|
|||
const llama_vocab & vocab,
|
||||
const llama_memory_i * memory,
|
||||
uint32_t n_embd,
|
||||
uint32_t n_seq_max,
|
||||
bool output_all) {
|
||||
clear();
|
||||
|
||||
|
@ -40,6 +41,11 @@ bool llama_batch_allocr::init(
|
|||
// validate input batch
|
||||
//
|
||||
|
||||
if (n_seq_max > LLAMA_MAX_SEQ) {
|
||||
LLAMA_LOG_ERROR("%s: n_seq_max = %d > %d\n", __func__, n_seq_max, LLAMA_MAX_SEQ);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (batch.token) {
|
||||
for (int32_t i = 0; i < batch.n_tokens; ++i) {
|
||||
if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= vocab.n_tokens()) {
|
||||
|
@ -52,8 +58,8 @@ bool llama_batch_allocr::init(
|
|||
if (batch.seq_id) {
|
||||
for (int32_t i = 0; i < batch.n_tokens; ++i) {
|
||||
for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
|
||||
if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= LLAMA_MAX_SEQ)) {
|
||||
LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], LLAMA_MAX_SEQ);
|
||||
if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= (llama_seq_id) n_seq_max)) {
|
||||
LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], (llama_seq_id) n_seq_max);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
@ -86,7 +92,7 @@ bool llama_batch_allocr::init(
|
|||
|
||||
// initialize the starting position for each sequence based on the positions in the memory
|
||||
llama_pos p0[LLAMA_MAX_SEQ];
|
||||
for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
||||
for (uint32_t s = 0; s < n_seq_max; ++s) {
|
||||
if (!memory) {
|
||||
// if no memory -> start from 0
|
||||
p0[s] = 0;
|
||||
|
@ -143,7 +149,8 @@ bool llama_batch_allocr::init(
|
|||
// compute stats
|
||||
//
|
||||
|
||||
this->n_embd = n_embd;
|
||||
this->n_embd = n_embd;
|
||||
this->n_seq_max = n_seq_max;
|
||||
|
||||
// count the outputs in this batch
|
||||
for (int32_t i = 0; i < batch.n_tokens; ++i) {
|
||||
|
@ -189,7 +196,7 @@ bool llama_batch_allocr::init(
|
|||
seq_set_map[cur].push_back(i);
|
||||
}
|
||||
|
||||
for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
||||
for (uint32_t s = 0; s < n_seq_max; ++s) {
|
||||
if (seq_set_unq.test(s)) {
|
||||
seq_idx[s] = seq_id_unq.size();
|
||||
seq_id_unq.push_back(s);
|
||||
|
@ -241,7 +248,7 @@ bool llama_batch_allocr::init(
|
|||
// consistency checks
|
||||
//
|
||||
|
||||
for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
||||
for (uint32_t s = 0; s < n_seq_max; ++s) {
|
||||
if (seq_pos[s].empty()) {
|
||||
continue;
|
||||
}
|
||||
|
@ -283,8 +290,8 @@ bool llama_batch_allocr::init(
|
|||
}
|
||||
|
||||
if (memory) {
|
||||
for (int32_t s0 = 0; s0 < LLAMA_MAX_SEQ; ++s0) {
|
||||
for (int32_t s1 = 0; s1 < LLAMA_MAX_SEQ; ++s1) {
|
||||
for (uint32_t s0 = 0; s0 < n_seq_max; ++s0) {
|
||||
for (uint32_t s1 = 0; s1 < n_seq_max; ++s1) {
|
||||
if (seq_cpl[s0][s1]) {
|
||||
if (memory->seq_pos_min(s0) != memory->seq_pos_min(s1) ||
|
||||
memory->seq_pos_max(s0) != memory->seq_pos_max(s1)) {
|
||||
|
@ -315,12 +322,12 @@ bool llama_batch_allocr::init(
|
|||
//
|
||||
{
|
||||
seq_set_t cur_seq_set[LLAMA_MAX_SEQ];
|
||||
for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
||||
for (uint32_t s = 0; s < n_seq_max; ++s) {
|
||||
cur_seq_set[s].set();
|
||||
}
|
||||
|
||||
llama_pos cur_seq_pos[LLAMA_MAX_SEQ];
|
||||
for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
||||
for (uint32_t s = 0; s < n_seq_max; ++s) {
|
||||
cur_seq_pos[s] = -1;
|
||||
}
|
||||
|
||||
|
@ -691,7 +698,7 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
|
|||
}
|
||||
}
|
||||
|
||||
for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
||||
for (uint32_t s = 0; s < n_seq_max; ++s) {
|
||||
if (seq_set_unq.test(s)) {
|
||||
ubatch.seq_idx[s] = ubatch.seq_id_unq.size();
|
||||
ubatch.seq_id_unq.push_back(s);
|
||||
|
|
|
@ -48,6 +48,7 @@ public:
|
|||
const llama_vocab & vocab,
|
||||
const llama_memory_i * memory,
|
||||
uint32_t n_embd,
|
||||
uint32_t n_seq_max,
|
||||
bool output_all);
|
||||
|
||||
const llama_batch & get_batch() const;
|
||||
|
@ -100,6 +101,7 @@ private:
|
|||
const uint32_t n_pos_per_embd;
|
||||
|
||||
uint32_t n_embd;
|
||||
uint32_t n_seq_max;
|
||||
uint32_t n_outputs;
|
||||
|
||||
std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id
|
||||
|
|
|
@ -98,10 +98,20 @@ llama_context::llama_context(
|
|||
LLAMA_LOG_WARN("%s: n_batch is less than GGML_KQ_MASK_PAD - increasing to %d\n", __func__, GGML_KQ_MASK_PAD);
|
||||
cparams.n_batch = GGML_KQ_MASK_PAD;
|
||||
}
|
||||
|
||||
cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
|
||||
|
||||
cparams.op_offload = params.op_offload;
|
||||
cparams.kv_unified = params.kv_unified;
|
||||
|
||||
{
|
||||
const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS");
|
||||
const bool supports_set_rows = LLAMA_SET_ROWS ? atoi(LLAMA_SET_ROWS) : 0;
|
||||
|
||||
if (!supports_set_rows && !cparams.kv_unified) {
|
||||
LLAMA_LOG_WARN("%s: non-unified KV cache requires ggml_set_rows() - forcing unified KV cache\n", __func__);
|
||||
cparams.kv_unified = true;
|
||||
}
|
||||
}
|
||||
|
||||
const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
|
||||
|
||||
|
@ -112,6 +122,7 @@ llama_context::llama_context(
|
|||
LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch);
|
||||
LLAMA_LOG_INFO("%s: causal_attn = %d\n", __func__, cparams.causal_attn);
|
||||
LLAMA_LOG_INFO("%s: flash_attn = %d\n", __func__, cparams.flash_attn);
|
||||
LLAMA_LOG_INFO("%s: kv_unified = %s\n", __func__, cparams.kv_unified ? "true" : "false");
|
||||
LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
|
||||
LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);
|
||||
|
||||
|
@ -267,7 +278,7 @@ llama_context::llama_context(
|
|||
|
||||
// reserve worst-case graph
|
||||
if (!hparams.vocab_only && memory) {
|
||||
const uint32_t n_seqs = cparams.n_seq_max;
|
||||
const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
|
||||
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
|
||||
|
||||
LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
|
||||
|
@ -300,7 +311,7 @@ llama_context::llama_context(
|
|||
|
||||
// reserve with tg graph to get the number of splits and nodes
|
||||
{
|
||||
auto * gf = graph_reserve(1, 1, 1, mctx.get());
|
||||
auto * gf = graph_reserve(n_seqs, n_seqs, n_seqs, mctx.get());
|
||||
if (!gf) {
|
||||
throw std::runtime_error("failed to allocate compute tg buffers");
|
||||
}
|
||||
|
@ -311,6 +322,10 @@ llama_context::llama_context(
|
|||
|
||||
// reserve again with pp graph to avoid ggml-alloc reallocations during inference
|
||||
{
|
||||
// TODO: not sure if the following graph would be worster case for multi-stream KV caches:
|
||||
//
|
||||
// auto * gf = graph_reserve(n_tokens, 1, n_tokens, mctx.get());
|
||||
//
|
||||
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
|
||||
if (!gf) {
|
||||
throw std::runtime_error("failed to allocate compute pp buffers");
|
||||
|
@ -475,7 +490,7 @@ bool llama_context::kv_self_update(bool optimize) {
|
|||
throw std::runtime_error("failed to initialize memory context");
|
||||
}
|
||||
|
||||
const uint32_t n_seqs = cparams.n_seq_max;
|
||||
const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
|
||||
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
|
||||
|
||||
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
|
||||
|
@ -735,13 +750,15 @@ int llama_context::encode(const llama_batch & batch_inp) {
|
|||
const int32_t n_vocab = model.vocab.n_tokens();
|
||||
|
||||
// note: during encode, we always pass the full sequence starting from pos = 0
|
||||
if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, true)) {
|
||||
if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) {
|
||||
LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
|
||||
return -1;
|
||||
}
|
||||
|
||||
const uint32_t n_tokens = balloc->get_n_tokens();
|
||||
|
||||
// [TAG_NO_CACHE_PAD]
|
||||
// TODO: add new split mode where we pad the input sequences so that ubatch.equal_seqs == true
|
||||
const llama_ubatch ubatch = balloc->split_simple(n_tokens);
|
||||
|
||||
// micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
|
||||
|
@ -910,7 +927,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
|||
// when computing embeddings, all tokens are output
|
||||
const bool output_all = cparams.embeddings;
|
||||
|
||||
if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, output_all)) {
|
||||
if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, output_all)) {
|
||||
LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
|
||||
return -1;
|
||||
}
|
||||
|
@ -2039,7 +2056,7 @@ void llama_context::opt_epoch_iter(
|
|||
batch.logits [pos_batch] = true;
|
||||
}
|
||||
|
||||
if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd, true)) {
|
||||
if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) {
|
||||
LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
|
||||
return;
|
||||
}
|
||||
|
@ -2198,6 +2215,7 @@ llama_context_params llama_context_default_params() {
|
|||
/*.no_perf =*/ true,
|
||||
/*.op_offload =*/ true,
|
||||
/*.swa_full =*/ true,
|
||||
/*.kv_unified =*/ false,
|
||||
};
|
||||
|
||||
return result;
|
||||
|
|
|
@ -11,8 +11,8 @@ struct llama_cparams {
|
|||
uint32_t n_batch;
|
||||
uint32_t n_ubatch;
|
||||
uint32_t n_seq_max;
|
||||
int n_threads; // number of threads to use for generation
|
||||
int n_threads_batch; // number of threads to use for batch processing
|
||||
int32_t n_threads; // number of threads to use for generation
|
||||
int32_t n_threads_batch; // number of threads to use for batch processing
|
||||
|
||||
float rope_freq_base;
|
||||
float rope_freq_scale;
|
||||
|
@ -33,6 +33,7 @@ struct llama_cparams {
|
|||
bool no_perf;
|
||||
bool warmup;
|
||||
bool op_offload;
|
||||
bool kv_unified;
|
||||
|
||||
enum llama_pooling_type pooling_type;
|
||||
|
||||
|
|
|
@ -982,13 +982,16 @@ ggml_tensor * llm_graph_context::build_attn_mha(
|
|||
float kq_scale) const {
|
||||
const bool v_trans = v->nb[1] > v->nb[2];
|
||||
|
||||
// split the batch into streams if needed
|
||||
const auto n_stream = k->ne[3];
|
||||
|
||||
q = ggml_reshape_4d(ctx0, q, q->ne[0], q->ne[1], q->ne[2]/n_stream, n_stream);
|
||||
|
||||
q = ggml_permute(ctx0, q, 0, 2, 1, 3);
|
||||
k = ggml_permute(ctx0, k, 0, 2, 1, 3);
|
||||
v = ggml_permute(ctx0, v, 0, 2, 1, 3);
|
||||
|
||||
const auto n_tokens = q->ne[1];
|
||||
const auto n_head = q->ne[2];
|
||||
const auto n_kv = k->ne[1];
|
||||
const auto n_kv = k->ne[1];
|
||||
|
||||
ggml_tensor * cur;
|
||||
|
||||
|
@ -1030,7 +1033,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
|
|||
#endif
|
||||
}
|
||||
|
||||
cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
|
||||
cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);
|
||||
} else {
|
||||
ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
|
||||
|
||||
|
@ -1075,7 +1078,8 @@ ggml_tensor * llm_graph_context::build_attn_mha(
|
|||
|
||||
cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
|
||||
|
||||
cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
|
||||
// recombine streams
|
||||
cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);
|
||||
|
||||
if (!cparams.offload_kqv) {
|
||||
// all nodes between the KV store and the attention output are run on the CPU
|
||||
|
@ -1122,6 +1126,10 @@ ggml_tensor * llm_graph_context::build_attn(
|
|||
|
||||
const auto & kq_mask = inp->get_kq_mask();
|
||||
|
||||
// [TAG_NO_CACHE_PAD]
|
||||
// TODO: if ubatch.equal_seqs == true, we can split the three tensors below into ubatch.n_seqs_unq streams
|
||||
assert(ubatch.equal_seqs == false);
|
||||
|
||||
ggml_tensor * q = q_cur;
|
||||
ggml_tensor * k = k_cur;
|
||||
ggml_tensor * v = v_cur;
|
||||
|
@ -1156,13 +1164,14 @@ static std::unique_ptr<llm_graph_input_attn_kv_unified> build_attn_inp_kv_unifie
|
|||
{
|
||||
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
|
||||
|
||||
const auto n_kv = mctx_cur->get_n_kv();
|
||||
const auto n_kv = mctx_cur->get_n_kv();
|
||||
const auto n_tokens = ubatch.n_tokens;
|
||||
const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
|
||||
|
||||
inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
|
||||
inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
|
||||
|
||||
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
|
||||
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
|
||||
ggml_set_input(inp->self_kq_mask);
|
||||
|
||||
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
|
||||
|
@ -1362,13 +1371,15 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
|
|||
|
||||
auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);
|
||||
|
||||
const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
|
||||
|
||||
{
|
||||
const auto n_kv = mctx_cur->get_base()->get_n_kv();
|
||||
|
||||
inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
|
||||
inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
|
||||
|
||||
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
|
||||
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
|
||||
ggml_set_input(inp->self_kq_mask);
|
||||
|
||||
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
|
||||
|
@ -1382,7 +1393,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
|
|||
inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
|
||||
inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
|
||||
|
||||
inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
|
||||
inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
|
||||
ggml_set_input(inp->self_kq_mask_swa);
|
||||
|
||||
inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
|
||||
|
|
|
@ -255,10 +255,10 @@ public:
|
|||
ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
|
||||
|
||||
ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
|
||||
ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch]
|
||||
ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
|
||||
|
||||
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch, 1, 1]
|
||||
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch, 1, 1]
|
||||
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
|
||||
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
|
||||
|
||||
const llama_hparams & hparams;
|
||||
const llama_cparams & cparams;
|
||||
|
@ -289,14 +289,14 @@ public:
|
|||
ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
|
||||
|
||||
ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
|
||||
ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch]
|
||||
ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
|
||||
ggml_tensor * self_k_idxs_swa = nullptr; // I64 [n_batch]
|
||||
ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch]
|
||||
ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
|
||||
|
||||
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch, 1, 1]
|
||||
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch, 1, 1]
|
||||
ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch, 1, 1]
|
||||
ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch, 1, 1]
|
||||
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
|
||||
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
|
||||
ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
|
||||
ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
|
||||
|
||||
const llama_hparams & hparams;
|
||||
const llama_cparams & cparams;
|
||||
|
|
|
@ -65,6 +65,46 @@ uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const {
|
|||
return n_embd_head_v * n_head_kv;
|
||||
}
|
||||
|
||||
bool llama_hparams::is_n_embd_k_gqa_variable() const {
|
||||
const uint32_t val = n_embd_k_gqa();
|
||||
for (uint32_t il = 0; il < n_layer; ++il) {
|
||||
if (val != n_embd_k_gqa(il)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
bool llama_hparams::is_n_embd_v_gqa_variable() const {
|
||||
const uint32_t val = n_embd_v_gqa();
|
||||
for (uint32_t il = 0; il < n_layer; ++il) {
|
||||
if (val != n_embd_v_gqa(il)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
uint32_t llama_hparams::n_embd_k_gqa_max() const {
|
||||
uint32_t val = n_embd_k_gqa();
|
||||
for (uint32_t il = 0; il < n_layer; ++il) {
|
||||
val = std::max(val, n_embd_k_gqa(il));
|
||||
}
|
||||
|
||||
return val;
|
||||
}
|
||||
|
||||
uint32_t llama_hparams::n_embd_v_gqa_max() const {
|
||||
uint32_t val = n_embd_v_gqa();
|
||||
for (uint32_t il = 0; il < n_layer; ++il) {
|
||||
val = std::max(val, n_embd_v_gqa(il));
|
||||
}
|
||||
|
||||
return val;
|
||||
}
|
||||
|
||||
uint32_t llama_hparams::n_embd_r() const {
|
||||
if (wkv_head_size != 0) {
|
||||
// for RWKV models
|
||||
|
|
|
@ -191,6 +191,14 @@ struct llama_hparams {
|
|||
// dimension of value embeddings across all k-v heads
|
||||
uint32_t n_embd_v_gqa(uint32_t il = 0) const;
|
||||
|
||||
// true if any layer has a different n_embd_k_gqa/n_embd_v_gqa
|
||||
bool is_n_embd_k_gqa_variable() const;
|
||||
bool is_n_embd_v_gqa_variable() const;
|
||||
|
||||
// return the maximum n_embd_k_gqa/n_embd_v_gqa across all layers
|
||||
uint32_t n_embd_k_gqa_max() const;
|
||||
uint32_t n_embd_v_gqa_max() const;
|
||||
|
||||
// dimension of the rolling state embeddings
|
||||
// corresponds to Mamba's conv_states size or RWKV's token_shift states size
|
||||
uint32_t n_embd_r() const;
|
||||
|
|
|
@ -18,16 +18,17 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
|
|||
bool v_trans,
|
||||
bool offload,
|
||||
bool swa_full,
|
||||
bool unified,
|
||||
uint32_t kv_size,
|
||||
uint32_t n_seq_max,
|
||||
uint32_t n_ubatch,
|
||||
uint32_t n_pad) : hparams(model.hparams) {
|
||||
uint32_t n_pad) : hparams(model.hparams), unified(unified) {
|
||||
llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
|
||||
llama_kv_cache_unified::layer_filter_cb filter_swa = [&](int32_t il) { return model.hparams.is_swa(il); };
|
||||
|
||||
const uint32_t size_base = kv_size;
|
||||
|
||||
uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_ubatch, n_pad));
|
||||
uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*(unified ? n_seq_max : 1) + n_ubatch, n_pad));
|
||||
|
||||
//kcpp: pad the swa kv cache as well, similar to extra_context_handle_fragmentation
|
||||
size_swa += 32;
|
||||
|
@ -45,14 +46,14 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
|
|||
|
||||
kv_base = std::make_unique<llama_kv_cache_unified>(
|
||||
model, std::move(filter_base), type_k, type_v,
|
||||
v_trans, offload, size_base, n_seq_max, n_pad,
|
||||
v_trans, offload, unified, size_base, n_seq_max, n_pad,
|
||||
0, LLAMA_SWA_TYPE_NONE);
|
||||
|
||||
LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);
|
||||
|
||||
kv_swa = std::make_unique<llama_kv_cache_unified>(
|
||||
model, std::move(filter_swa), type_k, type_v,
|
||||
v_trans, offload, size_swa, n_seq_max, n_pad,
|
||||
v_trans, offload, unified, size_swa, n_seq_max, n_pad,
|
||||
hparams.n_swa, hparams.swa_type);
|
||||
}
|
||||
|
||||
|
@ -104,6 +105,11 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
|
|||
|
||||
// first try simple split
|
||||
do {
|
||||
if (!unified) {
|
||||
// requires equal splits, so we skip the simple split
|
||||
break;
|
||||
}
|
||||
|
||||
balloc.split_reset();
|
||||
|
||||
std::vector<llama_ubatch> ubatches;
|
||||
|
@ -144,7 +150,7 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
|
|||
|
||||
std::vector<llama_ubatch> ubatches;
|
||||
while (true) {
|
||||
auto ubatch = balloc.split_equal(n_ubatch, false);
|
||||
auto ubatch = balloc.split_equal(n_ubatch, !unified);
|
||||
|
||||
if (ubatch.n_tokens == 0) {
|
||||
break;
|
||||
|
|
|
@ -20,6 +20,7 @@ public:
|
|||
bool v_trans,
|
||||
bool offload,
|
||||
bool swa_full,
|
||||
bool unified,
|
||||
uint32_t kv_size,
|
||||
uint32_t n_seq_max,
|
||||
uint32_t n_ubatch,
|
||||
|
@ -68,6 +69,8 @@ public:
|
|||
private:
|
||||
const llama_hparams & hparams;
|
||||
|
||||
const bool unified;
|
||||
|
||||
std::unique_ptr<llama_kv_cache_unified> kv_base;
|
||||
std::unique_ptr<llama_kv_cache_unified> kv_swa;
|
||||
};
|
||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -35,16 +35,50 @@ public:
|
|||
std::vector<uint32_t> ids;
|
||||
};
|
||||
|
||||
struct stream_copy_info {
|
||||
bool empty() const {
|
||||
assert(ssrc.size() == sdst.size());
|
||||
return ssrc.empty();
|
||||
}
|
||||
|
||||
std::vector<uint32_t> ssrc;
|
||||
std::vector<uint32_t> sdst;
|
||||
};
|
||||
|
||||
// for each ubatch, create a slot_info that contains information about where the ubatch should be inserted in the
|
||||
// KV cells. for example, cell indices for each token, such that: token[i] -> goes to cells[idxs[i]]
|
||||
struct slot_info {
|
||||
// data for ggml_set_rows
|
||||
using idx_vec_t = std::vector<uint32_t>;
|
||||
|
||||
idx_vec_t idxs;
|
||||
// number of streams: ns = s1 - s0 + 1
|
||||
llama_seq_id s0;
|
||||
llama_seq_id s1;
|
||||
|
||||
std::vector<llama_seq_id> strm; // [ns]
|
||||
std::vector<idx_vec_t> idxs; // [ns]
|
||||
|
||||
uint32_t head() const {
|
||||
return idxs.at(0);
|
||||
GGML_ASSERT(idxs.size() == 1);
|
||||
GGML_ASSERT(!idxs[0].empty());
|
||||
|
||||
return idxs[0][0];
|
||||
}
|
||||
|
||||
void resize(size_t n) {
|
||||
strm.resize(n);
|
||||
idxs.resize(n);
|
||||
}
|
||||
|
||||
size_t size() const {
|
||||
GGML_ASSERT(idxs.size() == strm.size());
|
||||
GGML_ASSERT(!idxs.empty());
|
||||
|
||||
return idxs[0].size();
|
||||
}
|
||||
|
||||
size_t n_stream() const {
|
||||
return strm.size();
|
||||
}
|
||||
|
||||
bool empty() const {
|
||||
|
@ -54,9 +88,6 @@ public:
|
|||
void clear() {
|
||||
idxs.clear();
|
||||
}
|
||||
|
||||
// TODO: implement
|
||||
//std::vector<idx_vec_t> seq_idxs;
|
||||
};
|
||||
|
||||
using slot_info_vec_t = std::vector<slot_info>;
|
||||
|
@ -68,6 +99,7 @@ public:
|
|||
ggml_type type_v,
|
||||
bool v_trans,
|
||||
bool offload,
|
||||
bool unified,
|
||||
uint32_t kv_size,
|
||||
uint32_t n_seq_max,
|
||||
uint32_t n_pad,
|
||||
|
@ -111,7 +143,8 @@ public:
|
|||
// llama_kv_cache_unified specific API
|
||||
//
|
||||
|
||||
uint32_t get_size() const;
|
||||
uint32_t get_size() const;
|
||||
uint32_t get_n_stream() const;
|
||||
|
||||
bool get_has_shift() const;
|
||||
|
||||
|
@ -122,8 +155,8 @@ public:
|
|||
uint32_t get_n_kv() const;
|
||||
|
||||
// get views of the current state of the cache
|
||||
ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv) const;
|
||||
ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const;
|
||||
ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;
|
||||
ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;
|
||||
|
||||
// store k_cur and v_cur in the cache based on the provided head location
|
||||
ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const;
|
||||
|
@ -137,7 +170,7 @@ public:
|
|||
// return empty vector on failure
|
||||
slot_info_vec_t prepare(const std::vector<llama_ubatch> & ubatches);
|
||||
|
||||
bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo);
|
||||
bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo, const stream_copy_info & sc_info);
|
||||
|
||||
// find a slot of kv cells that can hold the ubatch
|
||||
// if cont == true, then the slot must be continuous
|
||||
|
@ -157,8 +190,9 @@ public:
|
|||
void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
|
||||
void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
|
||||
|
||||
void set_input_k_shift(ggml_tensor * dst) const;
|
||||
|
||||
void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
|
||||
void set_input_k_shift (ggml_tensor * dst) const;
|
||||
void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
|
||||
|
||||
private:
|
||||
|
@ -172,15 +206,15 @@ private:
|
|||
|
||||
ggml_tensor * k;
|
||||
ggml_tensor * v;
|
||||
|
||||
std::vector<ggml_tensor *> k_stream;
|
||||
std::vector<ggml_tensor *> v_stream;
|
||||
};
|
||||
|
||||
bool v_trans = true; // the value tensor is transposed
|
||||
|
||||
// the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
|
||||
// note: this is not part of the KV state and it's only used to speed-up the find_slot() method
|
||||
uint32_t head = 0;
|
||||
|
||||
const uint32_t n_seq_max = 1;
|
||||
const uint32_t n_stream = 1;
|
||||
|
||||
// required padding
|
||||
const uint32_t n_pad = 1;
|
||||
|
@ -200,7 +234,17 @@ private:
|
|||
std::vector<ggml_context_ptr> ctxs;
|
||||
std::vector<ggml_backend_buffer_ptr> bufs;
|
||||
|
||||
llama_kv_cells_unified cells;
|
||||
// the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
|
||||
// note: this is not part of the KV state and it's only used to speed-up the find_slot() method
|
||||
std::vector<uint32_t> v_heads;
|
||||
|
||||
std::vector<llama_kv_cells_unified> v_cells;
|
||||
|
||||
// maps from a sequence id to a stream id
|
||||
std::vector<uint32_t> seq_to_stream;
|
||||
|
||||
// pending stream copies that will be applied during the next update
|
||||
stream_copy_info sc_info;
|
||||
|
||||
std::vector<kv_layer> layers;
|
||||
|
||||
|
@ -237,18 +281,25 @@ private:
|
|||
ggml_cgraph * gf,
|
||||
const defrag_info & dinfo) const;
|
||||
|
||||
void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
|
||||
void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
|
||||
struct cell_ranges_t {
|
||||
uint32_t strm;
|
||||
|
||||
bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
|
||||
bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
|
||||
std::vector<std::pair<uint32_t, uint32_t>> data; // ranges, from inclusive, to exclusive
|
||||
};
|
||||
|
||||
void state_write_meta(llama_io_write_i & io, const cell_ranges_t & cr, llama_seq_id seq_id = -1) const;
|
||||
void state_write_data(llama_io_write_i & io, const cell_ranges_t & cr) const;
|
||||
|
||||
bool state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
|
||||
bool state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count);
|
||||
};
|
||||
|
||||
class llama_kv_cache_unified_context : public llama_memory_context_i {
|
||||
public:
|
||||
// some shorthands
|
||||
using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
|
||||
using defrag_info = llama_kv_cache_unified::defrag_info;
|
||||
using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
|
||||
using defrag_info = llama_kv_cache_unified::defrag_info;
|
||||
using stream_copy_info = llama_kv_cache_unified::stream_copy_info;
|
||||
|
||||
// used for errors
|
||||
llama_kv_cache_unified_context(llama_memory_status status);
|
||||
|
@ -262,7 +313,8 @@ public:
|
|||
llama_kv_cache_unified * kv,
|
||||
llama_context * lctx,
|
||||
bool do_shift,
|
||||
defrag_info dinfo);
|
||||
defrag_info dinfo,
|
||||
stream_copy_info sc_info);
|
||||
|
||||
// used to create a batch procesing context from a batch
|
||||
llama_kv_cache_unified_context(
|
||||
|
@ -320,6 +372,8 @@ private:
|
|||
|
||||
defrag_info dinfo;
|
||||
|
||||
stream_copy_info sc_info;
|
||||
|
||||
//
|
||||
// batch processing context
|
||||
//
|
||||
|
|
|
@ -40,6 +40,7 @@ llama_memory_hybrid::llama_memory_hybrid(
|
|||
offload,
|
||||
kv_size,
|
||||
n_seq_max,
|
||||
1,
|
||||
n_pad,
|
||||
n_swa,
|
||||
swa_type
|
||||
|
|
|
@ -854,6 +854,21 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
|||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_DREAM:
|
||||
{
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
// Dream models are primarily 7B with 28 layers
|
||||
switch (hparams.n_layer) {
|
||||
case 28:
|
||||
type = LLM_TYPE_7B;
|
||||
break;
|
||||
default:
|
||||
type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
// Set non-causal attention for diffusion models
|
||||
hparams.causal_attn = false;
|
||||
}
|
||||
break;
|
||||
case LLM_ARCH_QWEN2MOE:
|
||||
{
|
||||
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false);
|
||||
|
@ -2766,12 +2781,14 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
|||
} break;
|
||||
case LLM_ARCH_QWEN2:
|
||||
case LLM_ARCH_QWEN2VL:
|
||||
case LLM_ARCH_DREAM:
|
||||
{
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||
|
||||
// output
|
||||
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
||||
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
|
||||
output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), {n_vocab}, TENSOR_NOT_REQUIRED);
|
||||
// if output is NULL, init from the input tok embed
|
||||
if (output == NULL) {
|
||||
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
|
||||
|
@ -7849,6 +7866,113 @@ struct llm_build_qwen2 : public llm_graph_context {
|
|||
// lm_head
|
||||
cur = build_lora_mm(model.output, cur);
|
||||
|
||||
if (model.output_b != nullptr) {
|
||||
cur = ggml_add(ctx0, cur, model.output_b);
|
||||
}
|
||||
|
||||
cb(cur, "result_output", -1);
|
||||
res->t_logits = cur;
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
}
|
||||
};
|
||||
|
||||
struct llm_build_dream : public llm_graph_context {
|
||||
llm_build_dream(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) :
|
||||
llm_graph_context(params) {
|
||||
//copied from qwen2
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
inpL = build_inp_embd(model.tok_embd);
|
||||
|
||||
// inp_pos - contains the positions
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
auto * inp_attn = build_attn_inp_no_cache();
|
||||
|
||||
ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
ggml_tensor * inpSA = inpL;
|
||||
|
||||
// norm
|
||||
cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
|
||||
cb(cur, "attn_norm", il);
|
||||
|
||||
// self-attention
|
||||
{
|
||||
// compute Q and K and RoPE them
|
||||
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
|
||||
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
|
||||
cb(Qcur, "Qcur", il);
|
||||
|
||||
ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
|
||||
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
|
||||
cb(Kcur, "Kcur", il);
|
||||
|
||||
ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
|
||||
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
||||
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
|
||||
|
||||
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
|
||||
Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
|
||||
cb(Qcur, "Qcur", il);
|
||||
cb(Kcur, "Kcur", il);
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr,
|
||||
nullptr, 1.0f / sqrtf(float(n_embd_head)), il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
||||
}
|
||||
|
||||
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
||||
cb(ffn_inp, "ffn_inp", il);
|
||||
|
||||
// feed-forward network
|
||||
cur = build_norm(ffn_inp, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il);
|
||||
cb(cur, "ffn_norm", il);
|
||||
|
||||
cur = build_ffn(cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL,
|
||||
model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, il);
|
||||
cb(cur, "ffn_out", il);
|
||||
|
||||
cur = ggml_add(ctx0, cur, ffn_inp);
|
||||
|
||||
cur = build_cvec(cur, il);
|
||||
cb(cur, "l_out", il);
|
||||
|
||||
// input for next layer
|
||||
inpL = cur;
|
||||
}
|
||||
|
||||
cur = inpL;
|
||||
|
||||
cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);
|
||||
|
||||
cb(cur, "result_norm", -1);
|
||||
res->t_embd = cur;
|
||||
|
||||
// lm_head
|
||||
cur = build_lora_mm(model.output, cur);
|
||||
|
||||
cb(cur, "result_output", -1);
|
||||
res->t_logits = cur;
|
||||
|
||||
|
@ -15863,6 +15987,7 @@ private:
|
|||
cb(zx, "mamba_in_proj", il);
|
||||
// {8192, 5, 1, 1} -> {8192, 1, 5, 1}
|
||||
zx = ggml_permute(ctx0, zx, 0, 2, 1, 3);
|
||||
zx = ggml_cont(ctx0, zx);
|
||||
zx = ggml_reshape_4d(ctx0, zx, head_dim * 2, n_heads, n_seq_tokens, n_seqs);
|
||||
cb(zx, "mamba_in_proj_out", il);
|
||||
|
||||
|
@ -15880,7 +16005,6 @@ private:
|
|||
// conv1d
|
||||
{
|
||||
// => {d_conv - 1 + n_seq_tokens, d_inner, n_seqs}
|
||||
x = ggml_view_2d(ctx0, x, d_inner, n_seq_tokens * n_seqs, d_inner * x->nb[0], 0);
|
||||
ggml_tensor * conv_x = ggml_concat(ctx0, conv, ggml_transpose(ctx0, x), 0);
|
||||
cb(conv_x, "mamba_conv1d_input", il);
|
||||
|
||||
|
@ -16587,6 +16711,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
|||
case LLM_ARCH_NOMIC_BERT_MOE:
|
||||
case LLM_ARCH_NEO_BERT:
|
||||
case LLM_ARCH_WAVTOKENIZER_DEC:
|
||||
case LLM_ARCH_DREAM:
|
||||
{
|
||||
res = nullptr;
|
||||
} break;
|
||||
|
@ -16627,7 +16752,18 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
|||
} else {
|
||||
const auto padding = llama_kv_cache_unified::get_padding(cparams);
|
||||
|
||||
cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
|
||||
uint32_t n_ctx_per_stream = cparams.n_ctx;
|
||||
|
||||
if (!cparams.kv_unified) {
|
||||
n_ctx_per_stream = (cparams.n_ctx + cparams.n_seq_max - 1)/cparams.n_seq_max;
|
||||
n_ctx_per_stream = GGML_PAD(n_ctx_per_stream, padding);
|
||||
|
||||
cparams.n_ctx = n_ctx_per_stream*cparams.n_seq_max;
|
||||
} else {
|
||||
n_ctx_per_stream = GGML_PAD(n_ctx_per_stream, padding);
|
||||
|
||||
cparams.n_ctx = n_ctx_per_stream;
|
||||
}
|
||||
|
||||
LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
|
||||
|
||||
|
@ -16641,7 +16777,8 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
|||
!cparams.flash_attn,
|
||||
cparams.offload_kqv,
|
||||
params.swa_full,
|
||||
cparams.n_ctx,
|
||||
cparams.kv_unified,
|
||||
n_ctx_per_stream,
|
||||
cparams.n_seq_max,
|
||||
cparams.n_ubatch,
|
||||
padding);
|
||||
|
@ -16655,7 +16792,8 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
|||
params.type_v,
|
||||
!cparams.flash_attn,
|
||||
cparams.offload_kqv,
|
||||
cparams.n_ctx,
|
||||
cparams.kv_unified,
|
||||
n_ctx_per_stream,
|
||||
cparams.n_seq_max,
|
||||
padding,
|
||||
hparams.n_swa,
|
||||
|
@ -16738,6 +16876,11 @@ llm_graph_result_ptr llama_model::build_graph(
|
|||
{
|
||||
llm = std::make_unique<llm_build_qwen2>(*this, params, gf);
|
||||
} break;
|
||||
case LLM_ARCH_DREAM:
|
||||
{
|
||||
llm = std::make_unique<llm_build_dream>(*this, params, gf);
|
||||
}
|
||||
break;
|
||||
case LLM_ARCH_QWEN2VL:
|
||||
{
|
||||
llm = std::make_unique<llm_build_qwen2vl>(*this, params, gf);
|
||||
|
@ -17155,6 +17298,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
|
|||
case LLM_ARCH_BITNET:
|
||||
case LLM_ARCH_QWEN:
|
||||
case LLM_ARCH_QWEN2:
|
||||
case LLM_ARCH_DREAM:
|
||||
case LLM_ARCH_QWEN2MOE:
|
||||
case LLM_ARCH_QWEN3:
|
||||
case LLM_ARCH_QWEN3MOE:
|
||||
|
|
|
@ -3635,6 +3635,10 @@ llama_token llama_vocab::token_fim_sep() const {
|
|||
return pimpl->special_fim_sep_id;
|
||||
}
|
||||
|
||||
llama_token llama_vocab::token_mask() const {
|
||||
return pimpl->special_mask_id;
|
||||
}
|
||||
|
||||
bool llama_vocab::get_add_space_prefix() const {
|
||||
return pimpl->add_space_prefix;
|
||||
}
|
||||
|
@ -3882,6 +3886,10 @@ llama_token llama_vocab_fim_sep(const struct llama_vocab * vocab) {
|
|||
return vocab->token_fim_sep();
|
||||
}
|
||||
|
||||
llama_token llama_vocab_mask(const struct llama_vocab* vocab) {
|
||||
return vocab->token_mask();
|
||||
}
|
||||
|
||||
// deprecated
|
||||
const char * llama_token_get_text(const struct llama_vocab * vocab, llama_token token) {
|
||||
return llama_vocab_get_text(vocab, token);
|
||||
|
|
|
@ -103,6 +103,7 @@ struct llama_vocab {
|
|||
llama_token token_sep() const;
|
||||
llama_token token_nl () const;
|
||||
llama_token token_pad() const;
|
||||
llama_token token_mask() const;
|
||||
|
||||
llama_token token_prefix() const;
|
||||
llama_token token_middle() const;
|
||||
|
|
|
@ -127,7 +127,6 @@ struct slot_params {
|
|||
std::vector<std::string> response_fields;
|
||||
bool timings_per_token = false;
|
||||
bool post_sampling_probs = false;
|
||||
bool ignore_eos = false;
|
||||
|
||||
struct common_params_sampling sampling;
|
||||
struct common_params_speculative speculative;
|
||||
|
@ -441,7 +440,6 @@ struct server_task {
|
|||
|
||||
{
|
||||
params.sampling.logit_bias.clear();
|
||||
params.ignore_eos = json_value(data, "ignore_eos", false);
|
||||
|
||||
const auto & logit_bias = data.find("logit_bias");
|
||||
if (logit_bias != data.end() && logit_bias->is_array()) {
|
||||
|
@ -472,6 +470,13 @@ struct server_task {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
params.sampling.ignore_eos = json_value(data, "ignore_eos", params_base.sampling.ignore_eos);
|
||||
if (params.sampling.ignore_eos) {
|
||||
params.sampling.logit_bias.insert(
|
||||
params.sampling.logit_bias.end(),
|
||||
defaults.sampling.logit_bias_eog.begin(), defaults.sampling.logit_bias_eog.end());
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
|
@ -1898,7 +1903,6 @@ struct server_context {
|
|||
|
||||
bool clean_kv_cache = true;
|
||||
bool add_bos_token = true;
|
||||
bool has_eos_token = false;
|
||||
|
||||
int32_t n_ctx; // total context for all clients / slots
|
||||
|
||||
|
@ -1957,7 +1961,6 @@ struct server_context {
|
|||
n_ctx = llama_n_ctx(ctx);
|
||||
|
||||
add_bos_token = llama_vocab_get_add_bos(vocab);
|
||||
has_eos_token = llama_vocab_eos(vocab) != LLAMA_TOKEN_NULL;
|
||||
|
||||
if (!params_base.speculative.model.path.empty() || !params_base.speculative.model.hf_repo.empty()) {
|
||||
SRV_INF("loading draft model '%s'\n", params_base.speculative.model.path.c_str());
|
||||
|
@ -2217,10 +2220,6 @@ struct server_context {
|
|||
slot.params.n_predict = slot.n_predict;
|
||||
}
|
||||
|
||||
if (slot.params.ignore_eos && has_eos_token) {
|
||||
slot.params.sampling.logit_bias.push_back({llama_vocab_eos(vocab), -INFINITY});
|
||||
}
|
||||
|
||||
{
|
||||
if (slot.smpl != nullptr) {
|
||||
common_sampler_free(slot.smpl);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue