Merge branch 'upstream' into concedo_experimental

# Conflicts:
#	.github/workflows/build.yml
#	README.md
#	ci/run.sh
#	docs/build.md
#	examples/CMakeLists.txt
#	examples/parallel/parallel.cpp
#	ggml/CMakeLists.txt
#	ggml/src/CMakeLists.txt
#	scripts/server-bench.py
#	src/llama-kv-cache-unified.cpp
#	tests/test-backend-ops.cpp
#	tools/batched-bench/batched-bench.cpp
#	tools/server/README.md
Author: Concedo
Date:   2025-07-17 00:28:37 +08:00
Commit: bdff33e0de
47 changed files with 3128 additions and 509 deletions

View file

@ -1466,6 +1466,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.swa_full = true;
}
).set_env("LLAMA_ARG_SWA_FULL"));
add_opt(common_arg(
{"--kv-unified", "-kvu"},
string_format("use single unified KV buffer for the KV cache of all sequences (default: %s)\n"
"[(more info)](https://github.com/ggml-org/llama.cpp/pull/14363)", params.kv_unified ? "true" : "false"),
[](common_params & params) {
params.kv_unified = true;
}
).set_env("LLAMA_ARG_KV_SPLIT"));
add_opt(common_arg(
{"--no-context-shift"},
string_format("disables context shift on infinite text generation (default: %s)", params.ctx_shift ? "disabled" : "enabled"),
@ -3425,5 +3433,34 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
// diffusion parameters
add_opt(common_arg(
{ "--diffusion-steps" }, "N",
string_format("number of diffusion steps (default: %d)", params.diffusion.steps),
[](common_params & params, int value) { params.diffusion.steps = value; }
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
add_opt(common_arg(
{ "--diffusion-eps" }, "F",
string_format("epsilon for timesteps (default: %.6f)", (double) params.diffusion.eps),
[](common_params & params, const std::string & value) { params.diffusion.eps = std::stof(value); }
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
add_opt(common_arg(
{ "--diffusion-algorithm" }, "N",
string_format("diffusion algorithm: 0=ORIGIN, 1=MASKGIT_PLUS, 2=TOPK_MARGIN, 3=ENTROPY (default: %d)",
params.diffusion.algorithm),
[](common_params & params, int value) { params.diffusion.algorithm = value; }
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
add_opt(common_arg(
{ "--diffusion-alg-temp" }, "F",
string_format("algorithm temperature (default: %.3f)", (double) params.diffusion.alg_temp),
[](common_params & params, const std::string & value) { params.diffusion.alg_temp = std::stof(value); }
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
add_opt(common_arg(
{ "--diffusion-visual" },
string_format("enable visual diffusion mode (show progressive generation) (default: %s)",
params.diffusion.visual_mode ? "true" : "false"),
[](common_params & params) { params.diffusion.visual_mode = true; }
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
return ctx_arg;
}

View file

@ -1013,15 +1013,21 @@ struct common_init_result common_init_from_params(common_params & params) {
params.sampling.ignore_eos = false;
}
if (params.sampling.ignore_eos) {
for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
if (llama_vocab_is_eog(vocab, i)) {
LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(lctx, i).c_str(), -INFINITY);
params.sampling.logit_bias.push_back({i, -INFINITY});
}
// initialize once
for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
if (llama_vocab_is_eog(vocab, i)) {
LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(lctx, i).c_str(), -INFINITY);
params.sampling.logit_bias_eog.push_back({i, -INFINITY});
}
}
if (params.sampling.ignore_eos) {
// add EOG biases to the active set of logit biases
params.sampling.logit_bias.insert(
params.sampling.logit_bias.end(),
params.sampling.logit_bias_eog.begin(), params.sampling.logit_bias_eog.end());
}
if (params.sampling.penalty_last_n == -1) {
LOG_INF("%s: setting penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
params.sampling.penalty_last_n = llama_n_ctx(lctx);
@ -1165,6 +1171,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &
cparams.no_perf = params.no_perf;
cparams.op_offload = !params.no_op_offload;
cparams.swa_full = params.swa_full;
cparams.kv_unified = params.kv_unified;
cparams.type_k = params.cache_type_k;
cparams.type_v = params.cache_type_v;
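
As a quick illustration of the kv_unified plumbing above, here is a minimal sketch (not part of this diff; the helper name and the n_ctx value are made up, the API calls are the ones used elsewhere in this commit):

#include "llama.h"

// Sketch: create a context whose KV cache is a single unified buffer shared by
// all sequences, mirroring what --kv-unified / params.kv_unified enables.
static llama_context * make_unified_ctx(llama_model * model) {
    llama_context_params cparams = llama_context_default_params();
    cparams.n_ctx      = 8192;   // placeholder value
    cparams.kv_unified = true;   // same field that is set from params.kv_unified above
    return llama_init_from_model(model, cparams);
}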

View file

@ -77,6 +77,7 @@ enum llama_example {
LLAMA_EXAMPLE_LOOKUP,
LLAMA_EXAMPLE_PARALLEL,
LLAMA_EXAMPLE_TTS,
LLAMA_EXAMPLE_DIFFUSION,
LLAMA_EXAMPLE_COUNT,
};
@ -173,7 +174,8 @@ struct common_params_sampling {
std::vector<common_grammar_trigger> grammar_triggers; // optional triggers (for lazy grammars)
std::set<llama_token> preserved_tokens;
std::vector<llama_logit_bias> logit_bias; // logit biases to apply
std::vector<llama_logit_bias> logit_bias_eog; // pre-calculated logit biases for EOG tokens
// print the parameters into a string
std::string print() const;
@ -213,6 +215,14 @@ struct common_params_vocoder {
bool use_guide_tokens = false; // enable guide tokens to improve TTS accuracy // NOLINT
};
struct common_params_diffusion {
int32_t steps = 64; // number of diffusion steps
float eps = 1e-3f; // epsilon for timesteps
int32_t algorithm = 0; // diffusion algorithm (0=ORIGIN, 1=MASKGIT_PLUS, 2=TOPK_MARGIN, 3=ENTROPY)
float alg_temp = 0.0f; // algorithm temperature
bool visual_mode = false; // show progressive diffusion on screen
};
enum common_reasoning_format {
COMMON_REASONING_FORMAT_NONE,
COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY, // Extract thinking tag contents and return as `message.reasoning_content`, or leave inline in <think> tags in stream mode
@ -264,6 +274,7 @@ struct common_params {
struct common_params_sampling sampling;
struct common_params_speculative speculative;
struct common_params_vocoder vocoder;
struct common_params_diffusion diffusion;
struct common_params_model model;
@ -326,6 +337,7 @@ struct common_params {
bool no_perf = false; // disable performance metrics
bool ctx_shift = true; // context shift on infinite text generation
bool swa_full = false; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
bool kv_unified = false; // enable unified KV cache
bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
bool use_mmap = true; // use mmap for faster loads

View file

@ -669,6 +669,36 @@ class TextModel(ModelBase):
# NOTE: if you get an error here, you need to update the convert_hf_to_gguf_update.py script
# or pull the latest version of the model from Huggingface
# don't edit the hashes manually!
if chkhsh == "b6e8e1518dc4305be2fe39c313ed643381c4da5db34a98f6a04c093f8afbe99b":
# ref: https://huggingface.co/THUDM/glm-4-9b-chat
res = "chatglm-bpe"
if chkhsh == "81d72c7348a9f0ebe86f23298d37debe0a5e71149e29bd283904c02262b27516":
# ref: https://huggingface.co/THUDM/glm-4-9b-chat
res = "chatglm-bpe"
if chkhsh == "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2":
# ref: https://huggingface.co/THUDM/glm-4-9b-hf
res = "glm4"
if chkhsh == "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35":
# ref: https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0
res = "minerva-7b"
if chkhsh == "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664":
# ref: https://huggingface.co/tencent/Hunyuan-A13B-Instruct
res = "hunyuan"
if chkhsh == "a6b57017d60e6edb4d88ecc2845188e0eb333a70357e45dcc9b53964a73bbae6":
# ref: https://huggingface.co/tiiuae/Falcon-H1-0.5B-Base
res = "falcon-h1"
if chkhsh == "60476e1243776c4fb1b993dbd7a5f15ac22f83c80afdf425fa5ae01c8d44ef86":
# ref: https://huggingface.co/tiiuae/Falcon-H1-1B-Base
res = "falcon-h1"
if chkhsh == "3eda48b4c4dc7de733d1a8b3e3b4a85243dbbf704da2ee9d42c6beced8897896":
# ref: https://huggingface.co/tiiuae/Falcon-H1-7B-Base
res = "falcon-h1"
if chkhsh == "48f8e02c0359c0bbdd82f26909171fac1c18a457bb47573ed1fe3bbb2c1cfd4b":
# ref: https://huggingface.co/tiiuae/Falcon-H1-34B-Base
res = "falcon-h1"
if chkhsh == "81212dc7cdb7e0c1074ca62c5aeab0d43c9f52b8a737be7b12a777c953027890":
# ref: https://huggingface.co/moonshotai/Kimi-K2-Base
res = "kimi-k2"
if chkhsh == "0ef9807a4087ebef797fc749390439009c3b9eda9ad1a097abbe738f486c01e5":
# ref: https://huggingface.co/meta-llama/Meta-Llama-3-8B
res = "llama-bpe"
@ -804,45 +834,15 @@ class TextModel(ModelBase):
if chkhsh == "d5f1dd6f980fec569fb218a81a7658ac45fc56b38c5a0adeb1c232fbe04ef5ec":
# ref: https://huggingface.co/ByteDance-Seed/Seed-Coder-8B-Base
res = "seed-coder"
if chkhsh == "b6e8e1518dc4305be2fe39c313ed643381c4da5db34a98f6a04c093f8afbe99b":
# ref: https://huggingface.co/THUDM/glm-4-9b-chat
res = "chatglm-bpe"
if chkhsh == "81d72c7348a9f0ebe86f23298d37debe0a5e71149e29bd283904c02262b27516":
# ref: https://huggingface.co/THUDM/glm-4-9b-chat
res = "chatglm-bpe"
if chkhsh == "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2":
# ref: https://huggingface.co/THUDM/glm-4-9b-hf
res = "glm4"
if chkhsh == "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35":
# ref: https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0
res = "minerva-7b"
if chkhsh == "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664":
# ref: https://huggingface.co/tencent/Hunyuan-A13B-Instruct
res = "hunyuan"
if chkhsh == "b0a6b1c0bd5998ebd9df08611efde34a4ff03faed45ae09c43e6b31ebd4b94cf":
# ref: https://huggingface.co/skt/A.X-4.0
res = "a.x-4.0"
if chkhsh == "a6b57017d60e6edb4d88ecc2845188e0eb333a70357e45dcc9b53964a73bbae6":
# ref: https://huggingface.co/tiiuae/Falcon-H1-0.5B-Base
res = "falcon-h1"
if chkhsh == "60476e1243776c4fb1b993dbd7a5f15ac22f83c80afdf425fa5ae01c8d44ef86":
# ref: https://huggingface.co/tiiuae/Falcon-H1-1B-Base
res = "falcon-h1"
if chkhsh == "3eda48b4c4dc7de733d1a8b3e3b4a85243dbbf704da2ee9d42c6beced8897896":
# ref: https://huggingface.co/tiiuae/Falcon-H1-7B-Base
res = "falcon-h1"
if chkhsh == "48f8e02c0359c0bbdd82f26909171fac1c18a457bb47573ed1fe3bbb2c1cfd4b":
# ref: https://huggingface.co/tiiuae/Falcon-H1-34B-Base
res = "falcon-h1"
if chkhsh == "f6791d196f87ce6b56a7d234be618e0d58f8cda3549416635b2bebcd22cd95c4":
# ref: https://huggingface.co/K-intelligence/Midm-2.0-Base-Instruct
res = "midm-2.0"
if chkhsh == "169bf0296a13c4d9b7672313f749eb36501d931022de052aad6e36f2bf34dd51":
# ref: https://huggingface.co/LiquidAI/LFM2-Tokenizer
res = "lfm2"
if chkhsh == "81212dc7cdb7e0c1074ca62c5aeab0d43c9f52b8a737be7b12a777c953027890":
# ref: https://huggingface.co/moonshotai/Kimi-K2-Base
res = "kimi-k2"
if res is None:
logger.warning("\n")
@ -2778,6 +2778,76 @@ class Qwen2Model(TextModel):
yield from super().modify_tensors(data_torch, name, bid)
@ModelBase.register("DreamModel")
class DreamModel(TextModel):
model_arch = gguf.MODEL_ARCH.DREAM
def get_vocab_base(self) -> tuple[list[str], list[int], str]:
tokens: list[str] = []
toktypes: list[int] = []
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True)
vocab_dict = tokenizer.get_vocab()
vocab_size = self.hparams.get("vocab_size", len(vocab_dict))
assert max(vocab_dict.values()) < vocab_size
tokpre = self.get_vocab_base_pre(tokenizer)
reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in vocab_dict.items()}
added_vocab = tokenizer.get_added_vocab()
for i in range(vocab_size):
if i not in reverse_vocab:
tokens.append(f"[PAD{i}]")
toktypes.append(gguf.TokenType.UNUSED)
elif reverse_vocab[i] in added_vocab:
tokens.append(reverse_vocab[i])
# Check if it's a special token - treat special tokens as CONTROL tokens
if hasattr(tokenizer, 'added_tokens_decoder') and i in tokenizer.added_tokens_decoder:
if tokenizer.added_tokens_decoder[i].special:
toktypes.append(gguf.TokenType.CONTROL)
else:
toktypes.append(gguf.TokenType.USER_DEFINED)
else:
# Fallback: treat all added vocab as control tokens for special tokens like <|im_start|>
toktypes.append(gguf.TokenType.CONTROL)
else:
tokens.append(reverse_vocab[i])
toktypes.append(gguf.TokenType.NORMAL)
return tokens, toktypes, tokpre
def set_vocab(self):
try:
self._set_vocab_sentencepiece()
except FileNotFoundError:
self._set_vocab_gpt2()
def set_gguf_parameters(self):
super().set_gguf_parameters()
self._try_set_pooling_type()
# Dream models use non-causal attention for diffusion
self.gguf_writer.add_causal_attention(False)
# Handle RoPE scaling similar to Qwen2
rope_scaling = self.hparams.get("rope_scaling") or {}
if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling:
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])
# Add Dream-specific parameters
mask_token_id = self.hparams.get("mask_token_id")
if mask_token_id is not None:
self.gguf_writer.add_mask_token_id(mask_token_id)
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# Dream model tensors should be mapped directly since it's the base model
yield from super().modify_tensors(data_torch, name, bid)
@ModelBase.register("Ernie4_5_ForCausalLM")
class Ernie4_5Model(TextModel):
model_arch = gguf.MODEL_ARCH.ERNIE4_5

View file

@ -232,7 +232,7 @@ for model in models:
# generate the source code for the convert_hf_to_gguf.py:get_vocab_base_pre() function:
src_ifs = ""
for model in [*all_models, *pre_computed_hashes]:
for model in [*pre_computed_hashes, *all_models]:
name = model["name"]
tokt = model["tokt"]
chkhsh = model.get("chkhsh")
@ -240,11 +240,6 @@ for model in [*all_models, *pre_computed_hashes]:
if tokt == TOKENIZER_TYPE.SPM or tokt == TOKENIZER_TYPE.UGM:
continue
# Skip if the tokenizer folder does not exist or there are other download issues previously
if not os.path.exists(f"models/tokenizers/{name}"):
logger.warning(f"Directory for tokenizer {name} not found. Skipping...")
continue
# create the tokenizer
if chkhsh is not None:
# if the model has a pre-computed hash, use it
@ -254,6 +249,12 @@ for model in [*all_models, *pre_computed_hashes]:
chkhsh = existing_models[name]
else:
# otherwise, compute the hash of the tokenizer
# Skip if the tokenizer folder does not exist or there are other download issues previously
if not os.path.exists(f"models/tokenizers/{name}"):
logger.warning(f"Directory for tokenizer {name} not found. Skipping...")
continue
try:
logger.info(f"Loading tokenizer from {f'models/tokenizers/{name}'}...")
if name == "t5":

View file

@ -0,0 +1,5 @@
set(TARGET llama-diffusion-cli)
add_executable(${TARGET} diffusion-cli.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE llama common ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_17)

View file

@ -0,0 +1,507 @@
#include "arg.h"
#include "chat.h"
#include "common.h"
#include "llama.h"
#include "log.h"
#include <limits.h>
#include <string>
#include <vector>
#include <algorithm>
#include <cmath>
#include <limits>
#include <random>
typedef bool (*diffusion_step_callback_t)(int32_t step,
int32_t total_steps,
const llama_token * tokens,
int32_t n_tokens,
void * user_data);
enum diffusion_alg {
DIFFUSION_ALG_ORIGIN = 0,
DIFFUSION_ALG_MASKGIT_PLUS = 1,
DIFFUSION_ALG_TOPK_MARGIN = 2,
DIFFUSION_ALG_ENTROPY = 3,
};
struct diffusion_params {
int32_t steps;
float eps;
float temperature;
float top_p;
int32_t top_k;
llama_token mask_token_id;
enum diffusion_alg algorithm;
float alg_temp;
diffusion_step_callback_t step_callback;
void * step_callback_user_data;
int32_t seed;
};
static diffusion_params diffusion_default_params() {
diffusion_params params = {};
params.steps = 64;
params.eps = 1e-3f;
params.temperature = 0.2f;
params.top_p = 0.95f;
params.top_k = 0;
params.mask_token_id = LLAMA_TOKEN_NULL;
params.algorithm = DIFFUSION_ALG_ORIGIN;
params.alg_temp = 0.0f;
params.step_callback = nullptr;
params.step_callback_user_data = nullptr;
params.seed = 0;
return params;
}
static void diffusion_generate(llama_context * ctx,
const llama_token * input_tokens,
llama_token * output_tokens,
int32_t n_input,
int32_t max_length,
struct diffusion_params params,
int32_t & n_generated) {
n_generated = 0;
if (!ctx || !input_tokens || !output_tokens || n_input <= 0 || max_length <= n_input) {
return;
}
const llama_model * model = llama_get_model(ctx);
// Initialize with input and pad with mask tokens
std::copy(input_tokens, input_tokens + n_input, output_tokens);
std::fill(output_tokens + n_input, output_tokens + max_length, params.mask_token_id);
std::mt19937 rng(params.seed);
std::vector<float> timesteps(params.steps + 1);
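// Timestep schedule decreasing linearly from 1.0 down to eps; consecutive entries
// (t, s) drive the per-step unmasking fraction 1 - s/t used below.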
for (int32_t i = 0; i <= params.steps; i++) {
timesteps[i] = 1.0f - (float) i / params.steps * (1.0f - params.eps);
}
llama_set_causal_attn(ctx, false);
int32_t n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(model));
std::vector<llama_token_data> candidates(n_vocab);
std::vector<llama_token_data> conf_candidates;
conf_candidates.reserve(max_length);
std::vector<int32_t> mask_positions;
mask_positions.reserve(max_length);
struct llama_sampler * sampler = llama_sampler_chain_init(llama_sampler_chain_default_params());
if (params.top_k > 0) {
llama_sampler_chain_add(sampler, llama_sampler_init_top_k(params.top_k));
}
if (params.top_p < 1.0f) {
llama_sampler_chain_add(sampler, llama_sampler_init_top_p(params.top_p, 1));
}
if (params.temperature > 0.0f) {
llama_sampler_chain_add(sampler, llama_sampler_init_temp(params.temperature));
}
llama_sampler_chain_add(sampler, llama_sampler_init_dist(params.seed));
struct llama_sampler * dist_sampler = llama_sampler_init_dist(params.seed);
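// The chain (top-k -> top-p -> temperature -> dist) both filters the candidates and
// selects a token when applied (llama_sampler_apply sets cur_p.selected). dist_sampler
// is a second, standalone distribution sampler used further down to stochastically pick
// which masked positions to commit when alg_temp > 0.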
llama_batch batch = llama_batch_init(max_length, 0, 1);
batch.n_tokens = max_length;
int64_t total_sampling_time = 0;
int64_t total_time = 0;
int64_t time_start = ggml_time_us();
for (int32_t step = 0; step < params.steps; step++) {
if (params.step_callback) {
if (!params.step_callback(step, params.steps, output_tokens, max_length, params.step_callback_user_data)) {
break;
}
}
for (int32_t i = 0; i < max_length; i++) {
batch.token[i] = output_tokens[i];
batch.pos[i] = i;
batch.n_seq_id[i] = 1;
batch.seq_id[i][0] = 0;
batch.logits[i] = 1;
}
int ret = llama_decode(ctx, batch);
if (ret != 0) {
LOG_ERR("%s: failed to decode at step %d, ret = %d\n", __func__, step, ret);
break;
}
float * raw_logits = llama_get_logits(ctx);
if (!raw_logits) {
LOG_ERR("%s: failed to get logits at step %d\n", __func__, step);
break;
}
auto get_logits_for_pos = [&](int32_t pos) -> const float * {
return pos == 0 ? raw_logits : raw_logits + (pos - 1) * n_vocab;
};
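// Output row pos - 1 is used to score the token at position pos (a one-token shift
// of the prediction head); position 0 falls back to row 0.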
int64_t time_start_sampling = ggml_time_us();
mask_positions.clear();
for (int32_t i = 0; i < max_length; i++) {
if (output_tokens[i] == params.mask_token_id) {
mask_positions.push_back(i);
}
}
if (mask_positions.empty()) {
break;
}
float t = timesteps[step];
float s = timesteps[step + 1];
if (params.algorithm == DIFFUSION_ALG_ORIGIN) {
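// ORIGIN: each still-masked position is independently unmasked with probability
// 1 - s/t (from consecutive timesteps); the final step unmasks everything.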
float p_transfer = (step < params.steps - 1) ? (1.0f - s / t) : 1.0f;
for (int32_t pos : mask_positions) {
if (std::uniform_real_distribution<float>(0.0f, 1.0f)(rng) < p_transfer) {
const float * pos_logits = get_logits_for_pos(pos);
for (int32_t token_id = 0; token_id < n_vocab; token_id++) {
candidates[token_id].id = token_id;
candidates[token_id].logit = pos_logits[token_id];
candidates[token_id].p = 0.0f;
}
llama_token_data_array cur_p = {
/* .data = */ candidates.data(),
/* .size = */ (size_t) n_vocab, // Reset size to full vocab
/* .selected = */ -1,
/* .sorted = */ false,
};
llama_sampler_apply(sampler, &cur_p);
output_tokens[pos] = cur_p.data[cur_p.selected].id;
}
}
} else {
std::vector<std::pair<float, int32_t>> confidences;
std::vector<llama_token> sampled_tokens(mask_positions.size());
for (size_t i = 0; i < mask_positions.size(); i++) {
int32_t pos = mask_positions[i];
const float * pos_logits = get_logits_for_pos(pos);
for (int32_t token_id = 0; token_id < n_vocab; token_id++) {
candidates[token_id].logit = pos_logits[token_id];
candidates[token_id].p = 0.0f;
candidates[token_id].id = token_id;
}
llama_token_data_array cur_p = {
/* .data = */ candidates.data(),
/* .size = */ candidates.size(),
/* .selected = */ -1,
/* .sorted = */ false,
};
llama_sampler_apply(sampler, &cur_p);
llama_token sampled_token = cur_p.data[cur_p.selected].id;
float confidence = 0.0f;
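// Confidence of committing this position:
//   ENTROPY     -> negative entropy of the candidate distribution (closer to 0 = more peaked)
//   TOPK_MARGIN -> gap between the first two candidates' probabilities (top-1 vs top-2 when sorted)
//   otherwise   -> probability assigned to the sampled token (MASKGIT_PLUS-style)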
if (params.algorithm == DIFFUSION_ALG_ENTROPY) {
const float epsilon = 1e-10f;
for (size_t j = 0; j < cur_p.size; j++) {
float prob = cur_p.data[j].p;
confidence += prob * logf(prob + epsilon);
}
} else if (params.algorithm == DIFFUSION_ALG_TOPK_MARGIN) {
confidence = cur_p.data[0].p - cur_p.data[1].p;
} else {
confidence = cur_p.data[cur_p.selected].p;
}
sampled_tokens[i] = sampled_token;
confidences.emplace_back(confidence, i);
}
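// Commit a (1 - s/t) fraction of the remaining masked positions this step;
// on the final step, commit all of them.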
int32_t num_transfer =
(step < params.steps - 1) ? (int32_t) (mask_positions.size() * (1.0f - s / t)) : mask_positions.size();
if (num_transfer > 0) {
if (params.alg_temp == 0.0f) {
std::partial_sort(confidences.begin(), confidences.begin() + num_transfer, confidences.end(),
[](const std::pair<float, int32_t> & a, const std::pair<float, int32_t> & b) {
if (a.first != b.first) {
return a.first > b.first;
}
return a.second < b.second;
});
} else {
conf_candidates.clear();
for (int32_t pos = 0; pos < max_length; pos++) {
float conf_logit = -std::numeric_limits<float>::infinity();
auto it = std::find(mask_positions.begin(), mask_positions.end(), pos);
if (it != mask_positions.end()) {
size_t mask_idx = std::distance(mask_positions.begin(), it);
conf_logit = confidences[mask_idx].first / params.alg_temp; // Apply temperature scaling
}
conf_candidates.emplace_back(llama_token_data{ pos, conf_logit, 0.0f });
}
llama_token_data_array conf_array = {
/* .data = */ conf_candidates.data(),
/* .size = */ conf_candidates.size(),
/* .selected = */ -1,
/* .sorted = */ false,
};
for (int32_t i = 0; i < num_transfer; i++) {
// Apply distribution sampler to get selected index
llama_sampler_apply(dist_sampler, &conf_array);
int selected_idx = conf_array.selected;
confidences[i].second = conf_candidates[selected_idx].id;
conf_candidates[selected_idx].p = 0.0f;
conf_array.selected = -1;
}
}
if (params.alg_temp == 0.0f) {
// Deterministic - use confidence order
for (int32_t i = 0; i < num_transfer; i++) {
int32_t mask_idx = confidences[i].second;
int32_t pos = mask_positions[mask_idx];
llama_token token = sampled_tokens[mask_idx];
output_tokens[pos] = token;
}
} else {
for (int32_t i = 0; i < num_transfer; i++) {
int32_t pos = confidences[i].second;
auto it = std::find(mask_positions.begin(), mask_positions.end(), pos);
if (it != mask_positions.end()) {
int32_t mask_idx = std::distance(mask_positions.begin(), it);
output_tokens[pos] = sampled_tokens[mask_idx];
}
}
}
}
}
int64_t time_end_sampling = ggml_time_us();
total_sampling_time += time_end_sampling - time_start_sampling;
}
int64_t time_end = ggml_time_us();
total_time += time_end - time_start;
LOG_INF("\ntotal time: %0.2fms, time per step: %0.2fms, sampling time per step: %0.2fms\n",
total_time / 1000.0, total_time / 1000.0 / params.steps, total_sampling_time / 1000.0 / params.steps);
llama_batch_free(batch);
llama_sampler_free(sampler);
llama_sampler_free(dist_sampler);
n_generated = max_length;
}
static std::string format_input_text(const std::string & prompt, bool use_chat_template, llama_model * model) {
if (!use_chat_template) {
return prompt;
}
auto chat_templates = common_chat_templates_init(model, "");
common_chat_templates_inputs inputs;
common_chat_msg user_msg;
user_msg.role = "user";
user_msg.content = prompt;
inputs.add_generation_prompt = true;
inputs.messages.push_back(user_msg);
auto result = common_chat_templates_apply(chat_templates.get(), inputs);
return result.prompt;
}
struct callback_data {
const common_params_diffusion * diff_params;
const llama_vocab * vocab;
int32_t n_input;
};
static bool diffusion_step_callback(int32_t step,
int32_t total_steps,
const llama_token * tokens,
int32_t n_tokens,
void * user_data) {
callback_data * data = static_cast<callback_data *>(user_data);
auto print_progress_bar = [](int32_t step, int32_t total_steps) {
int progress_percent = (step * 100) / total_steps;
int progress_bars = (step * 50) / total_steps;
LOG_INF("\rdiffusion step: %d/%d [%s%s] %d%%",
step,
total_steps,
std::string(progress_bars, '=').c_str(),
std::string(50 - progress_bars, ' ').c_str(),
progress_percent);
};
if (data->diff_params->visual_mode) {
// Visual mode: clear
LOG_INF("\033[2J\033[H"); // Clear screen and move cursor to top-left
print_progress_bar(step, total_steps);
LOG_INF("\n");
std::string current_text = " ";
for (int32_t i = data->n_input; i < n_tokens; i++) {
std::string token_str;
if (tokens[i] != llama_vocab_mask(data->vocab)) {
char piece[256];
int n_chars = llama_token_to_piece(data->vocab, tokens[i], piece, sizeof(piece), 0, false);
if (n_chars > 0) {
piece[n_chars] = '\0';
token_str = piece;
}
} else {
token_str = " ";
}
current_text += token_str;
}
LOG_INF("%s\n", current_text.c_str());
} else {
print_progress_bar(step, total_steps);
}
return true;
}
int main(int argc, char ** argv) {
ggml_time_init();
common_params params;
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_DIFFUSION)) {
return 1;
}
const char * alg_names[] = { "ORIGIN", "MASKGIT_PLUS", "TOPK_MARGIN", "ENTROPY" };
const char * alg_name = (params.diffusion.algorithm >= 0 && params.diffusion.algorithm <= 3) ?
alg_names[params.diffusion.algorithm] :
"UNKNOWN";
common_init();
llama_backend_init();
llama_model_params model_params = llama_model_default_params();
model_params.n_gpu_layers = params.n_gpu_layers;
model_params.devices = params.devices.data();
model_params.use_mmap = params.use_mmap;
model_params.use_mlock = params.use_mlock;
model_params.check_tensors = params.check_tensors;
llama_model * model = llama_model_load_from_file(params.model.path.c_str(), model_params);
if (!model) {
LOG_ERR("error: failed to load model '%s'\n", params.model.path.c_str());
return 1;
}
llama_context_params ctx_params = llama_context_default_params();
ctx_params.n_ctx = params.n_ctx;
ctx_params.n_batch = params.n_batch;
ctx_params.n_ubatch = params.n_ubatch;
ctx_params.flash_attn = params.flash_attn;
ctx_params.no_perf = params.no_perf;
ctx_params.type_k = params.cache_type_k;
ctx_params.type_v = params.cache_type_v;
llama_context * ctx = llama_init_from_model(model, ctx_params);
if (!ctx) {
LOG_ERR("error: failed to create context\n");
llama_model_free(model);
return 1;
}
llama_set_n_threads(ctx, params.cpuparams.n_threads, params.cpuparams_batch.n_threads);
const llama_vocab * vocab = llama_model_get_vocab(model);
std::string formatted_prompt = format_input_text(params.prompt, params.enable_chat_template, model);
std::vector<llama_token> input_tokens = common_tokenize(vocab, formatted_prompt,
/*add special tokens*/ true,
/*parse special*/ true);
int n_input = input_tokens.size();
if (n_input >= params.n_ctx) {
LOG_ERR("error: input too long (%d tokens), max context is %d\n", n_input, params.n_ctx);
llama_free(ctx);
llama_model_free(model);
return 1;
}
struct diffusion_params ldiff_params = diffusion_default_params();
ldiff_params.steps = params.diffusion.steps;
ldiff_params.eps = params.diffusion.eps;
ldiff_params.temperature = params.sampling.temp;
ldiff_params.top_p = params.sampling.top_p;
ldiff_params.top_k = params.sampling.top_k;
ldiff_params.algorithm = static_cast<enum diffusion_alg>(params.diffusion.algorithm);
ldiff_params.alg_temp = params.diffusion.alg_temp;
ldiff_params.seed = params.sampling.seed;
llama_token mask_token_id = llama_vocab_mask(vocab);
GGML_ASSERT(mask_token_id != LLAMA_TOKEN_NULL);
LOG_INF("diffusion_params: - %-25s llama_token = %d\n", "mask_token_id", mask_token_id);
LOG_INF("diffusion_params: - %-25s u32 = %d\n", "steps", params.diffusion.steps);
LOG_INF("diffusion_params: - %-25s f32 = %.6f\n", "eps", params.diffusion.eps);
LOG_INF("diffusion_params: - %-25s u32 = %d (%s)\n", "algorithm", params.diffusion.algorithm,
alg_name);
LOG_INF("diffusion_params: - %-25s f32 = %.3f\n", "alg_temp", params.diffusion.alg_temp);
ldiff_params.mask_token_id = mask_token_id;
callback_data cb_data = { &params.diffusion, vocab, n_input };
ldiff_params.step_callback = diffusion_step_callback;
ldiff_params.step_callback_user_data = &cb_data;
int32_t n_generated = 0;
std::vector<llama_token> output_tokens(params.n_ubatch);
diffusion_generate(ctx, input_tokens.data(), output_tokens.data(), n_input, params.n_ubatch,
ldiff_params, n_generated);
if (n_generated > 0) {
if (params.diffusion.visual_mode) {
//clear screen and move cursor to top-left
LOG_INF("\033[2J\033[H");
}
output_tokens.erase(output_tokens.begin(), output_tokens.begin() + n_input);
std::string output_data = common_detokenize(vocab, output_tokens, false);
LOG_INF("\n%s\n", output_data.c_str());
} else {
LOG_INF("Error: diffusion generation failed\n");
}
llama_free(ctx);
llama_model_free(model);
llama_backend_free();
return 0;
}

View file

@ -107,7 +107,7 @@ int main(int argc, char ** argv) {
const llama_vocab * vocab = llama_model_get_vocab(model);
const int n_ctx_train = llama_model_n_ctx_train(model);
const int n_ctx = llama_n_ctx(ctx);
const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);

View file

@ -0,0 +1,19 @@
#pragma once
#include "ggml.h"
#include "ggml-backend.h"
#ifdef __cplusplus
extern "C" {
#endif
#define GGML_WEBGPU_NAME "WebGPU"
// Needed for examples in ggml
GGML_BACKEND_API ggml_backend_t ggml_backend_webgpu_init(void);
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_webgpu_reg(void);
#ifdef __cplusplus
}
#endif

View file

@ -45,6 +45,10 @@
#include "ggml-vulkan.h"
#endif
#ifdef GGML_USE_WEBGPU
#include "ggml-webgpu.h"
#endif
#ifdef GGML_USE_OPENCL
#include "ggml-opencl.h"
#endif
@ -173,6 +177,9 @@ struct ggml_backend_registry {
#ifdef GGML_USE_VULKAN
register_backend(ggml_backend_vk_reg());
#endif
#ifdef GGML_USE_WEBGPU
register_backend(ggml_backend_webgpu_reg());
#endif
#ifdef GGML_USE_OPENCL
register_backend(ggml_backend_opencl_reg());
#endif

View file

@ -4015,6 +4015,9 @@ static void ggml_compute_forward_rms_norm_f32(
const float scale = 1.0f/sqrtf(mean + eps);
// if you hit this, likely you got an inf somewhere earlier
assert(scale > 0.0f);
ggml_vec_scale_f32(ne00, y, scale);
}
}

View file

@ -221,6 +221,9 @@ void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * G
for (int i = np; i < n; ++i) {
sumf += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[i])*GGML_CPU_FP16_TO_FP32(y[i]));
}
// if you hit this, you are likely running outside the FP range
assert(!isnan(sumf) && !isinf(sumf));
#else
for (int i = 0; i < n; ++i) {
sumf += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[i])*GGML_CPU_FP16_TO_FP32(y[i]));

View file

@ -33,8 +33,10 @@ typedef void (* fattn_kernel_t)(
const int ne13,
const int ne31,
const int ne32,
const int ne33,
const int nb31,
const int nb32,
const int nb33,
const int nb01,
const int nb02,
const int nb03,
@ -521,7 +523,7 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
template<int D, int ncols1, int ncols2> // D == head size
__launch_bounds__(D, 1)
static __global__ void flash_attn_stream_k_fixup(
float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) {
float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, const int ne11) {
constexpr int ncols = ncols1*ncols2;
const int bidx0 = blockIdx.x;
@ -535,8 +537,8 @@ static __global__ void flash_attn_stream_k_fixup(
const int iter_k = ne11 / FATTN_KQ_STRIDE;
const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
const int kbc0 = (bidx0 + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
const int kbc0 = (bidx0 + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
const int kbc0_stop = (bidx0 + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
const bool did_not_have_any_data = kbc0 == kbc0_stop;
const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
@ -545,14 +547,15 @@ static __global__ void flash_attn_stream_k_fixup(
return;
}
const int channel = kbc0 / (iter_k*iter_j);
const int jt = (kbc0 - channel*iter_k*iter_j) / iter_k;
const int sequence = kbc0 / (iter_k*iter_j*(ne02/ncols2));
const int head = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
const int jt = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
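// kbc0 linearly indexes (sequence, head group, j tile, k block) with the k block fastest,
// so the three values above are recovered by successive division and remainder.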
if (jt*ncols1 + j >= ne01) {
return;
}
dst += jt*ne02*(ncols1*D) + channel*(ncols2*D) + (j*ne02 + c)*D + tid;
dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + head*(ncols2*D) + (j*ne02 + c)*D + tid;
// Load the partial result that needs a fixup:
float dst_val = 0.0f;
@ -571,7 +574,7 @@ static __global__ void flash_attn_stream_k_fixup(
int bidx = bidx0 - 1;
int kbc_stop = kbc0;
while(true) {
const int kbc = bidx*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
const int kbc = bidx*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
if (kbc == kbc_stop) { // Did not have any data.
bidx--;
kbc_stop = kbc;
@ -617,16 +620,31 @@ static __global__ void flash_attn_combine_results(
const float2 * __restrict__ VKQ_meta,
float * __restrict__ dst,
const int parallel_blocks) {
VKQ_parts += parallel_blocks*D * gridDim.z*blockIdx.x;
VKQ_meta += parallel_blocks * gridDim.z*blockIdx.x;
dst += D * gridDim.z*blockIdx.x;
// Dimension 0: threadIdx.x
// Dimension 1: blockIdx.x
// Dimension 2: blockIdx.y
// Dimension 3: blockIdx.z
// Memory layout is permuted with [0, 2, 1, 3]
const int ne01 = gridDim.x;
const int ne02 = gridDim.y;
const int col = blockIdx.x;
const int head = blockIdx.y;
const int sequence = blockIdx.z;
const int j_dst_unrolled = (sequence*ne01 + col)*ne02 + head;
VKQ_parts += j_dst_unrolled * parallel_blocks*D;
VKQ_meta += j_dst_unrolled * parallel_blocks;
dst += j_dst_unrolled * D;
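// From here on, VKQ_parts, VKQ_meta and dst all point at the data belonging to this
// single (sequence, column, head) slot, so the loops below index only by l and tid.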
const int tid = threadIdx.x;
__builtin_assume(tid < D);
extern __shared__ float2 meta[];
for (int i = tid; i < 2*parallel_blocks; i += D) {
((float *) meta)[i] = ((const float *)VKQ_meta) [blockIdx.z*(2*parallel_blocks) + i];
((float *) meta)[i] = ((const float *)VKQ_meta) [i];
}
__syncthreads();
@ -644,11 +662,11 @@ static __global__ void flash_attn_combine_results(
const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
*((uint32_t *) &KQ_max_scale) &= ftz_mask;
VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.z*D + blockIdx.z*D + tid];
VKQ_numerator += KQ_max_scale * VKQ_parts[l*D + tid];
VKQ_denominator += KQ_max_scale * meta[l].y;
}
dst[blockIdx.z*D + tid] = VKQ_numerator / VKQ_denominator;
dst[tid] = VKQ_numerator / VKQ_denominator;
}
[[noreturn]]
@ -705,8 +723,6 @@ void launch_fattn(
GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
GGML_ASSERT(Q->ne[3] == 1);
ggml_cuda_pool & pool = ctx.pool();
cudaStream_t main_stream = ctx.stream();
const int id = ggml_cuda_get_device();
@ -853,8 +869,8 @@ void launch_fattn(
scale, max_bias, m0, m1, n_head_log2, logit_softcap,
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0,
mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0,
mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0,
mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0, mask ? mask->nb[3] : 0,
Q->nb[1], Q->nb[2], Q->nb[3],
nb11, nb12, nb13,
nb21, nb22, nb23,
@ -869,11 +885,11 @@ void launch_fattn(
flash_attn_stream_k_fixup<DV, ncols1, ncols2>
<<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1]);
}
} else if (parallel_blocks > 1) {
const dim3 block_dim_combine(DV, 1, 1);
const dim3 blocks_num_combine(Q->ne[1], 1, blocks_num.z);
const dim3 blocks_num_combine(Q->ne[1], Q->ne[2], Q->ne[3]);
const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2);
flash_attn_combine_results<DV>

View file

@ -1224,8 +1224,10 @@ static __global__ void flash_attn_ext_f16(
const int ne13,
const int ne31,
const int ne32,
const int ne33,
const int nb31,
const int nb32,
const int nb33,
const int nb01,
const int nb02,
const int nb03,
@ -1274,8 +1276,8 @@ static __global__ void flash_attn_ext_f16(
constexpr int kb_niter = FATTN_KQ_STRIDE / c::nbatch_fa; // Number of kernel iterations per assigned KQ slice.
// kbc == k block continuous, current index in continuous ijk space.
int kbc = (blockIdx.x + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
const int kbc_stop = (blockIdx.x + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
int kbc = (blockIdx.x + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
const int kbc_stop = (blockIdx.x + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
// If the seams of 2 CUDA blocks fall within an output tile their results need to be combined.
// For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup).
@ -1285,18 +1287,19 @@ static __global__ void flash_attn_ext_f16(
int kb0_start = kbc % iter_k;
int kb0_stop = min(iter_k, kb0_start + kbc_stop - kbc);
while (kbc < kbc_stop && kb0_stop == iter_k) {
const int channel = kbc / (iter_k*iter_j);
const int jt = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile.
const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
const int head = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*(head*ncols2));
const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head*ncols2 / gqa_ratio));
const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr :
(const half2 *) (mask + nb32*(channel % ne32) + nb31*jt*ncols1);
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
(const half2 *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1);
float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head*ncols2) * (DV/2);
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head*ncols2 / gqa_ratio));
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f;
const int kb0_start_kernel = kb0_start * kb_niter;
const int kb0_stop_kernel = kb0_stop * kb_niter;
@ -1325,18 +1328,19 @@ static __global__ void flash_attn_ext_f16(
return;
}
const int channel = kbc / (iter_k*iter_j);
const int jt = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile.
const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
const int head = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*(head*ncols2));
const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head*ncols2 / gqa_ratio));
const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr :
(const half2 *) (mask + nb32*(channel % ne32) + nb31*jt*ncols1);
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
(const half2 *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1);
float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head*ncols2) * (DV/2);
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head*ncols2 / gqa_ratio));
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f;
const int kb0_start_kernel = kb0_start * kb_niter;
const int kb0_stop_kernel = kb0_stop * kb_niter;

View file

@ -31,8 +31,10 @@ static __global__ void flash_attn_tile_ext_f16(
const int ne13,
const int ne31,
const int ne32,
const int ne33,
const int nb31,
const int nb32,
const int nb33,
const int nb01,
const int nb02,
const int nb03,
@ -62,15 +64,17 @@ static __global__ void flash_attn_tile_ext_f16(
const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
const int sequence = blockIdx.z / ne02;
const int head = blockIdx.z - sequence*ne02;
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.z + nb01*ic0);
const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.z / gqa_ratio));
const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0);
const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio));
const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
const int stride_KV2 = nb11 / sizeof(half2);
const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
const half slopeh = __float2half(slopef);
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
@ -255,6 +259,8 @@ static __global__ void flash_attn_tile_ext_f16(
__syncthreads();
}
float2 * dst2 = (float2 *) dst;
#pragma unroll
for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
const int j_VKQ = j_VKQ_0 + threadIdx.y;
@ -266,21 +272,21 @@ static __global__ void flash_attn_tile_ext_f16(
half kqsum_j = __low2half(kqsum[j_VKQ_0/nwarps]) + __high2half(kqsum[j_VKQ_0/nwarps]);
kqsum_j = warp_reduce_sum((float)kqsum_j);
#pragma unroll
for (int i00 = 0; i00 < D; i00 += 2*WARP_SIZE) {
const int i0 = i00 + 2*threadIdx.x;
const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
half2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)];
#pragma unroll
for (int i00 = 0; i00 < D/2; i00 += WARP_SIZE) {
const int i0 = i00 + threadIdx.x;
half2 dst_val = VKQ[j_VKQ_0/nwarps][i0/WARP_SIZE];
if (gridDim.y == 1) {
dst_val /= __half2half2(kqsum_j);
}
const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 0] = __low2float(dst_val);
dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 1] = __high2float(dst_val);
dst2[j_dst_unrolled*(D/2) + i0] = __half22float2(dst_val);
}
if (gridDim.y != 1 && threadIdx.x == 0) {
dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
dst_meta[j_dst_unrolled] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
}
}
#else
@ -290,8 +296,8 @@ static __global__ void flash_attn_tile_ext_f16(
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);

View file

@ -31,8 +31,10 @@ static __global__ void flash_attn_tile_ext_f32(
const int ne13,
const int ne31,
const int ne32,
const int ne33,
const int nb31,
const int nb32,
const int nb33,
const int nb01,
const int nb02,
const int nb03,
@ -74,15 +76,17 @@ static __global__ void flash_attn_tile_ext_f32(
const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
const int sequence = blockIdx.z / ne02;
const int head = blockIdx.z - sequence*ne02;
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.z + nb01*ic0);
const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.z / gqa_ratio));
const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0);
const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio));
const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
const int stride_KV2 = nb11 / sizeof(half2);
const float slope = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
@ -265,6 +269,8 @@ static __global__ void flash_attn_tile_ext_f32(
__syncthreads();
}
float2 * dst2 = (float2 *) dst;
#pragma unroll
for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
const int j_VKQ = j_VKQ_0 + threadIdx.y;
@ -276,22 +282,22 @@ static __global__ void flash_attn_tile_ext_f32(
float kqsum_j = kqsum[j_VKQ_0/nwarps];
kqsum_j = warp_reduce_sum(kqsum_j);
#pragma unroll
for (int i00 = 0; i00 < D; i00 += 2*WARP_SIZE) {
const int i0 = i00 + 2*threadIdx.x;
const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)];
#pragma unroll
for (int i00 = 0; i00 < D/2; i00 += WARP_SIZE) {
const int i0 = i00 + threadIdx.x;
float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/WARP_SIZE];
if (gridDim.y == 1) {
dst_val.x /= kqsum_j;
dst_val.y /= kqsum_j;
}
const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 0] = dst_val.x;
dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 1] = dst_val.y;
dst2[j_dst_unrolled*(D/2) + i0] = dst_val;
}
if (gridDim.y != 1 && threadIdx.x == 0) {
dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
dst_meta[j_dst_unrolled] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
}
}
#else

View file

@ -28,8 +28,10 @@ static __global__ void flash_attn_vec_ext_f16(
const int ne13,
const int ne31,
const int ne32,
const int ne33,
const int nb31,
const int nb32,
const int nb33,
const int nb01,
const int nb02,
const int nb03,
@ -65,14 +67,16 @@ static __global__ void flash_attn_vec_ext_f16(
const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
const int sequence = blockIdx.z / ne02;
const int head = blockIdx.z - sequence*ne02;
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
Q += nb02* blockIdx.z + nb01*ic0;
K += nb12*(blockIdx.z / gqa_ratio);
V += nb22*(blockIdx.z / gqa_ratio);
Q += nb03*sequence + nb02* head + nb01*ic0;
K += nb13*sequence + nb12*(head / gqa_ratio);
V += nb23*sequence + nb22*(head / gqa_ratio);
const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
const half slopeh = __float2half(slopef);
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
@ -330,12 +334,11 @@ static __global__ void flash_attn_vec_ext_f16(
if (gridDim.y == 1) {
dst_val /= kqsum[j_VKQ];
}
const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
dst[j_dst*D*gridDim.z + D*blockIdx.z + tid] = dst_val;
dst[(((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*D + tid] = dst_val;
}
if (gridDim.y != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
dst_meta[((ic0 + tid)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
}
#else
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
@ -344,8 +347,8 @@ static __global__ void flash_attn_vec_ext_f16(
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);

View file

@ -28,8 +28,10 @@ static __global__ void flash_attn_vec_ext_f32(
const int ne13,
const int ne31,
const int ne32,
const int ne33,
const int nb31,
const int nb32,
const int nb33,
const int nb01,
const int nb02,
const int nb03,
@ -53,8 +55,8 @@ static __global__ void flash_attn_vec_ext_f32(
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
@ -77,14 +79,16 @@ static __global__ void flash_attn_vec_ext_f32(
const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
const int sequence = blockIdx.z / ne02;
const int head = blockIdx.z - sequence*ne02;
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
Q += nb02* blockIdx.z + nb01*ic0;
K += nb12*(blockIdx.z / gqa_ratio);
V += nb22*(blockIdx.z / gqa_ratio); // K and V have same shape
Q += nb03*sequence + nb02* head + nb01*ic0;
K += nb13*sequence + nb12*(head / gqa_ratio);
V += nb23*sequence + nb22*(head / gqa_ratio);
const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
const float slope = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
constexpr int nwarps = D / WARP_SIZE;
@ -326,12 +330,11 @@ static __global__ void flash_attn_vec_ext_f32(
if (gridDim.y == 1) {
dst_val /= kqsum[j_VKQ];
}
const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
dst[j_dst*D*gridDim.z + D*blockIdx.z + tid] = dst_val;
dst[(((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*D + tid] = dst_val;
}
if (gridDim.y != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
dst_meta[((ic0 + tid)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
}
#else
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
@ -340,8 +343,8 @@ static __global__ void flash_attn_vec_ext_f32(
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
GGML_UNUSED(ne31); GGML_UNUSED(ne32);
GGML_UNUSED(nb31); GGML_UNUSED(nb32);
GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33);
GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);

View file

@ -47,8 +47,10 @@ static __global__ void flash_attn_ext_f16(
const int ne13,
const int ne31,
const int ne32,
const int ne33,
const int nb31,
const int nb32,
const int nb33,
const int nb01,
const int nb02,
const int nb03,
@ -95,17 +97,19 @@ static __global__ void flash_attn_ext_f16(
constexpr int kqs_padded = FATTN_KQ_STRIDE + 8;
constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half);
const int sequence = blockIdx.z / ne02;
const int head = blockIdx.z - sequence*ne02;
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
const float * Q_f = (const float *) (Q + nb02* blockIdx.z + nb01*ic0);
const half * K_h = (const half *) (K + nb12*(blockIdx.z / gqa_ratio));
const half * V_h = (const half *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
const float * Q_f = (const float *) (Q + nb03* sequence + nb02* head + nb01*ic0);
const half * K_h = (const half *) (K + nb13* sequence + nb12*(head / gqa_ratio));
const half * V_h = (const half *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
const half2 * mask2 = (const half2 *) maskh;
const int stride_Q = nb01 / sizeof(float);
const int stride_KV = nb11 / sizeof(half);
const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
const half slopeh = __float2half(slopef);
const half2 slope2 = make_half2(slopef, slopef);
@ -400,7 +404,6 @@ static __global__ void flash_attn_ext_f16(
if (ic0 + j_VKQ >= ne01) {
return;
}
const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
float KQ_rowsum_j;
if (std::is_same<KQ_acc_t, float>::value) {
@ -409,6 +412,8 @@ static __global__ void flash_attn_ext_f16(
KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]);
}
const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
#pragma unroll
for (int i0 = 0; i0 < D; i0 += warp_size) {
const int i = i0 + threadIdx.x;
@ -419,7 +424,7 @@ static __global__ void flash_attn_ext_f16(
if (gridDim.y == 1) {
dst_val /= KQ_rowsum_j;
}
dst[j_dst*gridDim.z*D + blockIdx.z*D + i] = dst_val;
dst[j_dst_unrolled*D + i] = dst_val;
}
if (gridDim.y == 1 || threadIdx.x != 0) {
@ -433,7 +438,7 @@ static __global__ void flash_attn_ext_f16(
dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]);
}
dst_meta_val.y = KQ_rowsum_j;
dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = dst_meta_val;
dst_meta[j_dst_unrolled] = dst_meta_val;
}
#else
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
@ -442,7 +447,8 @@ static __global__ void flash_attn_ext_f16(
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33); GGML_UNUSED(nb31);
GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);

View file

@ -3418,12 +3418,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
if (op->src[0]->ne[0] == 192) {
return false;
}
// TODO: support broadcast
// note: this was initially implemented in https://github.com/ggml-org/llama.cpp/pull/14500, but
// the interface of ggml_flash_attn_ext() changed in https://github.com/ggml-org/llama.cpp/pull/14505
if (op->src[0]->ne[3] != 1) {
return false;
}
if (op->src[1]->type == GGML_TYPE_BF16 || op->src[2]->type == GGML_TYPE_BF16) {
return false;
}
@ -3436,6 +3430,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
if (op->src[0]->ne[0] == 256 && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) {
return true;
}
if (op->src[3] && op->src[3]->ne[2] != 1) {
return false;
}
return fp16_mma_available(ggml_cuda_info().devices[dev_ctx->device].cc) &&
op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
}

View file

@ -0,0 +1,54 @@
cmake_minimum_required(VERSION 3.13)
find_package(Python3 REQUIRED)
# Shader locations
set(SHADER_DIR "${CMAKE_CURRENT_SOURCE_DIR}/wgsl-shaders")
set(SHADER_OUTPUT_DIR "${CMAKE_CURRENT_BINARY_DIR}/generated")
set(SHADER_HEADER "${SHADER_OUTPUT_DIR}/ggml-wgsl-shaders.hpp")
file(MAKE_DIRECTORY ${SHADER_OUTPUT_DIR})
message(STATUS "Shader output dir: ${SHADER_OUTPUT_DIR}")
# Find all WGSL files
file(GLOB WGSL_SHADER_FILES "${SHADER_DIR}/*.wgsl")
# Generate the header using a Python script
add_custom_command(
OUTPUT ${SHADER_HEADER}
COMMAND ${CMAKE_COMMAND} -E echo "Embedding WGSL shaders into ggml-wgsl-shaders.hpp"
COMMAND ${CMAKE_COMMAND} -E make_directory ${SHADER_OUTPUT_DIR}
COMMAND ${CMAKE_COMMAND} -E env PYTHONIOENCODING=utf-8
${Python3_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/wgsl-shaders/embed_wgsl.py
--input "${SHADER_DIR}"
--output "${SHADER_HEADER}"
DEPENDS ${WGSL_SHADER_FILES} ${CMAKE_CURRENT_SOURCE_DIR}/wgsl-shaders/embed_wgsl.py
VERBATIM
)
add_custom_target(generate_shaders DEPENDS ${SHADER_HEADER})
ggml_add_backend_library(ggml-webgpu
ggml-webgpu.cpp
${SHADER_HEADER}
../../include/ggml-webgpu.h
)
add_dependencies(ggml-webgpu generate_shaders)
if(EMSCRIPTEN)
set(EMDAWNWEBGPU_DIR "" CACHE PATH "Path to emdawnwebgpu_pkg")
target_compile_options(ggml-webgpu PRIVATE "--use-port=${EMDAWNWEBGPU_DIR}/emdawnwebgpu.port.py")
target_link_options(ggml-webgpu PRIVATE "--use-port=${EMDAWNWEBGPU_DIR}/emdawnwebgpu.port.py")
else()
find_package(Dawn REQUIRED)
set(DawnWebGPU_TARGET dawn::webgpu_dawn)
endif()
if (GGML_WEBGPU_DEBUG)
target_compile_definitions(ggml-webgpu PRIVATE GGML_WEBGPU_DEBUG=1)
endif()
target_include_directories(ggml-webgpu PRIVATE ${SHADER_OUTPUT_DIR})
target_link_libraries(ggml-webgpu PRIVATE ${DawnWebGPU_TARGET})

View file

@ -0,0 +1,907 @@
#include "ggml-webgpu.h"
#include <webgpu/webgpu_cpp.h>
#include "ggml-impl.h"
#include "ggml-backend-impl.h"
#include "ggml-wgsl-shaders.hpp"
#include <cstring>
#include <iostream>
#include <mutex>
#include <vector>
#ifdef GGML_WEBGPU_DEBUG
#define WEBGPU_LOG_DEBUG(msg) std::cout << msg << std::endl
#else
#define WEBGPU_LOG_DEBUG(msg) ((void) 0)
#endif // GGML_WEBGPU_DEBUG
/* Constants */
#define WEBGPU_MUL_MAT_WG_SIZE 64
#define WEBGPU_MUL_MAT_PARAMS_SIZE (13 * sizeof(uint32_t)) // M, N, K, batch sizes, broadcasts
#define WEBGPU_CPY_PARAMS_SIZE (15 * sizeof(uint32_t)) // strides and offsets
#define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4
/* End Constants */
// This is a "fake" base pointer, since WebGPU buffers do not have pointers to their locations.
static void * const webgpu_ptr_base = (void *)(uintptr_t) 0x1000; // NOLINT
// Always returns the base offset of a tensor, regardless of views.
static uint64_t webgpu_tensor_offset(const ggml_tensor * tensor) {
if (tensor->view_src) {
return (uint8_t *) tensor->view_src->data - (uint8_t *) webgpu_ptr_base;
}
return (uint8_t *) tensor->data - (uint8_t *) webgpu_ptr_base;
}
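
Since WebGPU buffers are not host-addressable, the backend stores a fake pointer (base + byte offset) in tensor->data and recovers the offset by subtracting the base, as above. A minimal standalone sketch of that round trip, using a stand-in base constant rather than the backend's allocator:

// Minimal sketch of the fake-pointer scheme, assuming allocation stores
// data = base + byte_offset. Names here are stand-ins, not the backend's.
#include <cstdint>
#include <cstdio>

static void * const fake_base = (void *)(uintptr_t) 0x1000;

static void * fake_addr(uint64_t byte_offset) {
    return (uint8_t *) fake_base + byte_offset;      // what an "allocation" would store
}

static uint64_t fake_offset(const void * data) {
    return (uint8_t *) data - (uint8_t *) fake_base; // what webgpu_tensor_offset() recovers
}

int main() {
    void * p = fake_addr(4096);
    printf("recovered offset: %llu\n", (unsigned long long) fake_offset(p)); // 4096
    return 0;
}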
/* Struct definitions */
// All the base objects needed to run operations on a WebGPU device
struct webgpu_context_struct {
wgpu::Instance instance;
wgpu::Adapter adapter;
wgpu::Device device;
wgpu::Queue queue;
wgpu::Limits limits;
wgpu::SupportedFeatures features;
std::mutex mutex;
bool device_initialized = false;
// pipelines and parameter buffers
// TODO: reuse params buffers for different pipelines when possible
wgpu::ComputePipeline memset_pipeline;
wgpu::Buffer memset_params_dev_buf;
wgpu::Buffer memset_params_host_buf;
wgpu::ComputePipeline mul_mat_pipeline;
wgpu::Buffer mul_mat_params_dev_buf;
wgpu::Buffer mul_mat_params_host_buf;
wgpu::ComputePipeline cpy_pipeline;
wgpu::Buffer cpy_params_dev_buf;
wgpu::Buffer cpy_params_host_buf;
size_t memset_bytes_per_thread;
// Staging buffer for reading data from the GPU
wgpu::Buffer get_tensor_staging_buf;
};
typedef std::shared_ptr<webgpu_context_struct> webgpu_context;
struct ggml_backend_webgpu_reg_context {
webgpu_context webgpu_ctx;
size_t device_count;
const char * name;
};
struct ggml_backend_webgpu_device_context {
webgpu_context webgpu_ctx;
std::string device_name;
std::string device_desc;
};
struct ggml_backend_webgpu_context {
webgpu_context webgpu_ctx;
std::string name;
};
struct ggml_backend_webgpu_buffer_context {
webgpu_context webgpu_ctx;
wgpu::Buffer buffer;
ggml_backend_webgpu_buffer_context(webgpu_context ctx, wgpu::Buffer buf) :
webgpu_ctx(ctx), buffer(buf) {
}
};
/* End struct definitions */
/* WebGPU object initializations */
static void ggml_webgpu_create_pipeline(wgpu::Device &device, wgpu::ComputePipeline &pipeline, const char * shader_code, const char * label, const std::vector<wgpu::ConstantEntry> &constants = {}) {
WEBGPU_LOG_DEBUG("ggml_webgpu_create_pipeline()");
wgpu::ShaderSourceWGSL shader_source;
shader_source.code = shader_code;
wgpu::ShaderModuleDescriptor shader_desc;
shader_desc.nextInChain = &shader_source;
wgpu::ShaderModule shader_module = device.CreateShaderModule(&shader_desc);
wgpu::ComputePipelineDescriptor pipeline_desc;
pipeline_desc.label = label;
pipeline_desc.compute.module = shader_module;
pipeline_desc.compute.entryPoint = "main"; // Entry point in the WGSL code
pipeline_desc.layout = nullptr; // nullptr means auto layout
if (constants.size() > 0) {
pipeline_desc.compute.constants = constants.data();
pipeline_desc.compute.constantCount = constants.size();
}
pipeline = device.CreateComputePipeline(&pipeline_desc);
}
static void ggml_webgpu_create_buffer(wgpu::Device &device, wgpu::Buffer &buffer, size_t size, wgpu::BufferUsage usage, const char* label) {
WEBGPU_LOG_DEBUG("ggml_webgpu_create_buffer()");
wgpu::BufferDescriptor buffer_desc;
buffer_desc.size = size;
buffer_desc.usage = usage;
buffer_desc.label = label;
buffer_desc.mappedAtCreation = false;
// TODO: error handling
buffer = device.CreateBuffer(&buffer_desc);
}
/** End WebGPU object initializations */
/** WebGPU Actions */
static void ggml_backend_webgpu_map_buffer(webgpu_context ctx, wgpu::Buffer buffer, wgpu::MapMode mode, size_t offset, size_t size) {
ctx->instance.WaitAny(buffer.MapAsync(
mode, offset, size, wgpu::CallbackMode::WaitAnyOnly,
[](wgpu::MapAsyncStatus status, wgpu::StringView message) {
if (status != wgpu::MapAsyncStatus::Success) {
GGML_LOG_ERROR("ggml_webgpu: Failed to map buffer: %s\n", message.data);
}
}),
UINT64_MAX
);
}
static void ggml_backend_webgpu_buffer_memset(webgpu_context ctx, wgpu::Buffer buf, uint32_t value, size_t offset, size_t size) {
std::lock_guard<std::mutex> lock(ctx->mutex);
wgpu::Device device = ctx->device;
// map the host parameters buffer
ggml_backend_webgpu_map_buffer(ctx, ctx->memset_params_host_buf, wgpu::MapMode::Write, 0, ctx->memset_params_host_buf.GetSize());
uint32_t * params = (uint32_t *) ctx->memset_params_host_buf.GetMappedRange();
params[0] = (uint32_t)offset;
params[1] = (uint32_t)size;
params[2] = value;
ctx->memset_params_host_buf.Unmap();
wgpu::BindGroupEntry entries[2];
entries[0].binding = 0; // binding for the buffer to memset
entries[0].buffer = buf;
entries[0].offset = 0;
entries[0].size = buf.GetSize();
entries[1].binding = 1; // binding for the parameters
entries[1].buffer = ctx->memset_params_dev_buf;
entries[1].offset = 0;
entries[1].size = ctx->memset_params_dev_buf.GetSize();
wgpu::BindGroupDescriptor bind_group_desc;
bind_group_desc.layout = ctx->memset_pipeline.GetBindGroupLayout(0);
bind_group_desc.entryCount = 2;
bind_group_desc.label = "ggml_memset";
bind_group_desc.entries = entries;
wgpu::BindGroup bind_group = device.CreateBindGroup(&bind_group_desc);
wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
encoder.CopyBufferToBuffer(
ctx->memset_params_host_buf, 0,
ctx->memset_params_dev_buf, 0,
ctx->memset_params_dev_buf.GetSize()
);
wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
pass.SetPipeline(ctx->memset_pipeline);
pass.SetBindGroup(0, bind_group);
size_t bytes_per_wg = ctx->limits.maxComputeWorkgroupSizeX * ctx->memset_bytes_per_thread;
pass.DispatchWorkgroups(((size + 3) + bytes_per_wg - 1) / bytes_per_wg, 1, 1);
pass.End();
wgpu::CommandBuffer commands = encoder.Finish();
ctx->queue.Submit(1, &commands);
}
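
The dispatch above sizes the grid as ceil((size + 3) / bytes_per_wg) with bytes_per_wg = maxComputeWorkgroupSizeX * memset_bytes_per_thread; the +3 presumably rounds a partial trailing word up, since the shader writes whole u32s. A hedged sketch with made-up limit values checking that this count always covers every byte:

// Hedged check of the workgroup count used above. The limit values are
// invented stand-ins for wgpu::Limits, not queried from a device.
#include <cassert>
#include <cstddef>

int main() {
    const size_t wg_size          = 256; // stand-in for maxComputeWorkgroupSizeX
    const size_t bytes_per_thread = 16;  // stand-in for memset_bytes_per_thread
    const size_t bytes_per_wg     = wg_size * bytes_per_thread;

    for (size_t size : {1, 3, 4, 4095, 4096, 4097, 1 << 20}) {
        const size_t n_wg = ((size + 3) + bytes_per_wg - 1) / bytes_per_wg;
        // every dispatched thread handles bytes_per_thread bytes
        assert(n_wg * bytes_per_wg >= size);
    }
    return 0;
}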
static void ggml_backend_webgpu_wait_on_submission(webgpu_context ctx) {
// Wait for the queue to finish processing all commands
ctx->instance.WaitAny(ctx->queue.OnSubmittedWorkDone(wgpu::CallbackMode::WaitAnyOnly,
[](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
if (status != wgpu::QueueWorkDoneStatus::Success) {
GGML_LOG_ERROR("ggml_webgpu: Failed to wait on queue: %s\n", message.data);
}
}),
UINT64_MAX
);
}
/** End WebGPU Actions */
/** GGML Backend Interface */
static const char * ggml_backend_webgpu_name(ggml_backend_t backend) {
ggml_backend_webgpu_context * ctx = (ggml_backend_webgpu_context *)backend->context;
return ctx->name.c_str();
}
static void ggml_backend_webgpu_free(ggml_backend_t backend) {
ggml_backend_webgpu_context * ctx = (ggml_backend_webgpu_context *)backend->context;
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_free(" << ctx->name << ")");
// TODO: cleanup
GGML_UNUSED(ctx);
}
// Returns true if node has enqueued work into the queue, false otherwise
static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node){
if (ggml_is_empty(node)) {
return false;
}
WEBGPU_LOG_DEBUG("ggml_webgpu_encode_node(" << node << ", " << ggml_op_name(node->op) << ")");
switch (node->op) {
// no-ops
case GGML_OP_NONE:
case GGML_OP_VIEW:
case GGML_OP_PERMUTE:
return false;
case GGML_OP_CPY: {
std::lock_guard<std::mutex> lock(ctx->mutex);
const ggml_tensor * src = node->src[0];
ggml_backend_webgpu_buffer_context * src_ctx = (ggml_backend_webgpu_buffer_context *) src->buffer->context;
size_t src_offset = webgpu_tensor_offset(src) + src->view_offs;
// assumes power of 2 offset alignment
size_t src_misalignment = src_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
// align to minimum offset alignment
src_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
ggml_backend_webgpu_buffer_context * dst_ctx = (ggml_backend_webgpu_buffer_context *) node->buffer->context;
size_t dst_offset = webgpu_tensor_offset(node) + node->view_offs;
size_t dst_misalignment = dst_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
dst_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
wgpu::Device device = ctx->device;
ggml_backend_webgpu_map_buffer(ctx, ctx->cpy_params_host_buf,
wgpu::MapMode::Write, 0, ctx->cpy_params_host_buf.GetSize());
uint32_t * params = (uint32_t *) ctx->cpy_params_host_buf.GetMappedRange();
uint32_t ne = (uint32_t)ggml_nelements(node);
params[0] = ne;
params[1] = src_misalignment/ggml_type_size(src->type);
params[2] = dst_misalignment/ggml_type_size(node->type);
// Convert byte-strides to element-strides
params[3] = (uint32_t)src->nb[0]/ggml_type_size(src->type);
params[4] = (uint32_t)src->nb[1]/ggml_type_size(src->type);
params[5] = (uint32_t)src->nb[2]/ggml_type_size(src->type);
params[6] = (uint32_t)src->nb[3]/ggml_type_size(src->type);
params[7] = (uint32_t)node->nb[0]/ggml_type_size(node->type);
params[8] = (uint32_t)node->nb[1]/ggml_type_size(node->type);
params[9] = (uint32_t)node->nb[2]/ggml_type_size(node->type);
params[10] = (uint32_t)node->nb[3]/ggml_type_size(node->type);
// Logical shape — same for both tensors even if permuted
params[11] = (uint32_t)(src->ne[0]);
params[12] = (uint32_t)(src->ne[1]);
params[13] = (uint32_t)(src->ne[2]);
params[14] = (uint32_t)(src->ne[3]);
ctx->cpy_params_host_buf.Unmap();
wgpu::BindGroupEntry entries[3];
entries[0].binding = 0;
entries[0].buffer = src_ctx->buffer;
entries[0].offset = src_offset;
entries[0].size = (ggml_nbytes(src) + src_misalignment + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) & ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1);
entries[1].binding = 1;
entries[1].buffer = dst_ctx->buffer;
entries[1].offset = dst_offset;
entries[1].size = (ggml_nbytes(node) + dst_misalignment + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) & ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1);
entries[2].binding = 2;
entries[2].buffer = ctx->cpy_params_dev_buf;
entries[2].offset = 0;
entries[2].size = ctx->cpy_params_dev_buf.GetSize();
wgpu::BindGroupDescriptor bind_group_desc;
bind_group_desc.layout = ctx->cpy_pipeline.GetBindGroupLayout(0);
bind_group_desc.label = "ggml_op_cpy";
bind_group_desc.entryCount = 3;
bind_group_desc.entries = entries;
wgpu::BindGroup bind_group = device.CreateBindGroup(&bind_group_desc);
wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
encoder.CopyBufferToBuffer(
ctx->cpy_params_host_buf, 0,
ctx->cpy_params_dev_buf, 0,
ctx->cpy_params_dev_buf.GetSize()
);
wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
pass.SetPipeline(ctx->cpy_pipeline);
pass.SetBindGroup(0, bind_group);
size_t max_wg_size = ctx->limits.maxComputeWorkgroupSizeX;
pass.DispatchWorkgroups((ne + max_wg_size - 1) / max_wg_size);
pass.End();
wgpu::CommandBuffer commands = encoder.Finish();
// TODO, don't submit here, batch submissions
ctx->queue.Submit(1, &commands);
// TODO, don't wait on submission here
ggml_backend_webgpu_wait_on_submission(ctx);
return true;
}
case GGML_OP_MUL_MAT:
{
const ggml_tensor * src0 = node->src[0];
ggml_backend_webgpu_buffer_context * src0_ctx = (ggml_backend_webgpu_buffer_context *) src0->buffer->context;
size_t src0_offset = webgpu_tensor_offset(src0) + src0->view_offs;
const ggml_tensor * src1 = node->src[1];
ggml_backend_webgpu_buffer_context * src1_ctx = (ggml_backend_webgpu_buffer_context *) src1->buffer->context;
size_t src1_offset = webgpu_tensor_offset(src1) + src1->view_offs;
ggml_backend_webgpu_buffer_context * dst_ctx = (ggml_backend_webgpu_buffer_context *) node->buffer->context;
size_t dst_offset = webgpu_tensor_offset(node) + node->view_offs;
wgpu::Device device = ctx->device;
// map the host parameters buffer
ggml_backend_webgpu_map_buffer(ctx, ctx->mul_mat_params_host_buf,
wgpu::MapMode::Write, 0, ctx->mul_mat_params_host_buf.GetSize());
uint32_t * params = (uint32_t *) ctx->mul_mat_params_host_buf.GetMappedRange();
params[0] = (uint32_t)node->ne[1]; // number of rows in result (M)
params[1] = (uint32_t)node->ne[0]; // number of columns in result (N)
params[2] = (uint32_t)src0->ne[0]; // number of columns in src0/src1 (K)
params[3] = (uint32_t)src0->nb[1]/ggml_type_size(src0->type); // stride (elements) of src0 in dimension 1
params[4] = (uint32_t)src1->nb[1]/ggml_type_size(src1->type); // stride (elements) of src1 in dimension 1
params[5] = (uint32_t)src0->nb[2]/ggml_type_size(src0->type); // stride (elements) of src0 in dimension 2
params[6] = (uint32_t)src1->nb[2]/ggml_type_size(src1->type); // stride (elements) of src1 in dimension 2
params[7] = (uint32_t)src0->nb[3]/ggml_type_size(src0->type); // stride (elements) of src0 in dimension 3
params[8] = (uint32_t)src1->nb[3]/ggml_type_size(src1->type); // stride (elements) of src1 in dimension 3
params[9] = (uint32_t)src0->ne[2]; // batch size in dimension 2
params[10] = (uint32_t)src0->ne[3]; // batch size in dimension 3
params[11] = (uint32_t)(src1->ne[2]/src0->ne[2]); // broadcast in dimension 2
params[12] = (uint32_t)(src1->ne[3]/src0->ne[3]); // broadcast in dimension 3
ctx->mul_mat_params_host_buf.Unmap();
wgpu::BindGroupEntry entries[4];
entries[0].binding = 0;
entries[0].buffer = src0_ctx->buffer;
entries[0].offset = src0_offset;
entries[0].size = ggml_nbytes(src0);
entries[1].binding = 1;
entries[1].buffer = src1_ctx->buffer;
entries[1].offset = src1_offset;
entries[1].size = ggml_nbytes(src1);
entries[2].binding = 2;
entries[2].buffer = dst_ctx->buffer;
entries[2].offset = dst_offset;
entries[2].size = ggml_nbytes(node);
entries[3].binding = 3;
entries[3].buffer = ctx->mul_mat_params_dev_buf;
entries[3].offset = 0;
entries[3].size = ctx->mul_mat_params_dev_buf.GetSize();
wgpu::BindGroupDescriptor bind_group_desc;
bind_group_desc.layout = ctx->mul_mat_pipeline.GetBindGroupLayout(0);
bind_group_desc.entryCount = 4;
bind_group_desc.label = "ggml_op_mul_mat";
bind_group_desc.entries = entries;
wgpu::BindGroup bind_group = device.CreateBindGroup(&bind_group_desc);
wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
encoder.CopyBufferToBuffer(
ctx->mul_mat_params_host_buf, 0,
ctx->mul_mat_params_dev_buf, 0,
ctx->mul_mat_params_dev_buf.GetSize()
);
wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
pass.SetPipeline(ctx->mul_mat_pipeline);
pass.SetBindGroup(0, bind_group);
pass.DispatchWorkgroups((node->ne[0] * node->ne[1] * node->ne[2] * node->ne[3] + WEBGPU_MUL_MAT_WG_SIZE - 1) / WEBGPU_MUL_MAT_WG_SIZE);
pass.End();
wgpu::CommandBuffer commands = encoder.Finish();
// TODO, don't submit here, batch submissions
ctx->queue.Submit(1, &commands);
// TODO, don't wait on submission here
ggml_backend_webgpu_wait_on_submission(ctx);
return true;
}
default:
return false;
}
}
static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_graph_compute(" << cgraph->n_nodes << " nodes)");
ggml_backend_webgpu_context * backend_ctx = static_cast<ggml_backend_webgpu_context *>(backend->context);
webgpu_context ctx = backend_ctx->webgpu_ctx;
for (int i = 0; i < cgraph->n_nodes; i++) {
ggml_webgpu_encode_node(ctx, cgraph->nodes[i]);
}
return GGML_STATUS_SUCCESS;
}
static ggml_backend_i ggml_backend_webgpu_i = {
/* .get_name = */ ggml_backend_webgpu_name,
/* .free = */ ggml_backend_webgpu_free,
/* .set_tensor_async = */ NULL,
/* .get_tensor_async = */ NULL,
/* .cpy_tensor_async = */ NULL,
/* .synchronize = */ NULL,
/* .graph_plan_create = */ NULL,
/* .graph_plan_free = */ NULL,
/* .graph_plan_update = */ NULL,
/* .graph_plan_compute = */ NULL,
/* .graph_compute = */ ggml_backend_webgpu_graph_compute,
/* .event_record = */ NULL,
/* .event_wait = */ NULL,
};
/* End GGML Backend Interface */
/* GGML Backend Buffer Interface */
static void ggml_backend_webgpu_buffer_free_buffer(ggml_backend_buffer_t buffer) {
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_free_buffer()");
ggml_backend_webgpu_buffer_context * ctx = static_cast<ggml_backend_webgpu_buffer_context *>(buffer->context);
ctx->buffer.Destroy();
}
// Returns the "fake" base pointer.
static void * ggml_backend_webgpu_buffer_get_base(ggml_backend_buffer_t buffer) {
GGML_UNUSED(buffer);
return webgpu_ptr_base;
}
static void ggml_backend_webgpu_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
if (size == 0) {
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_memset_tensor: size is zero, nothing to do.");
return;
}
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_memset_tensor(" << buffer << ", " << tensor << ", " << value << ", " << offset << ", " << size << ")");
ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
// This is a trick to set all bytes of a u32 to the same 1 byte value.
uint32_t val32 = (uint32_t)value * 0x01010101;
ggml_backend_webgpu_buffer_memset(buf_ctx->webgpu_ctx, buf_ctx->buffer, val32, total_offset, size);
}
static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_set_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")");
ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
webgpu_context webgpu_ctx = buf_ctx->webgpu_ctx;
size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
webgpu_ctx->queue.WriteBuffer(buf_ctx->buffer, total_offset, data, (size/4)*4);
if (size % 4 != 0) {
// If size is not a multiple of 4, we need to memset the remaining bytes
size_t remaining_size = size % 4;
// pack the remaining bytes into a uint32_t
uint32_t val32 = 0;
for (size_t i = 0; i < remaining_size; i++) {
((uint8_t *)&val32)[i] = ((const uint8_t *)data)[size - remaining_size + i];
}
// memset the remaining bytes
ggml_backend_webgpu_buffer_memset(webgpu_ctx, buf_ctx->buffer, val32, total_offset + (size - remaining_size), remaining_size);
}
}
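
WriteBuffer handles the 4-byte-aligned prefix; the final 1-3 bytes are packed into a u32 and spliced in through the memset path. A standalone sketch (no WebGPU objects) reproducing the packing:

// Hedged sketch of the tail handling above: pack the last (size % 4) bytes of
// `data` into a u32 exactly as the code does, so the memset shader can splice
// them into the destination word.
#include <cstdint>
#include <cstdio>

int main() {
    const uint8_t data[7] = { 0x11, 0x22, 0x33, 0x44, 0xAA, 0xBB, 0xCC };
    const size_t  size    = sizeof(data);

    const size_t remaining = size % 4; // 3 trailing bytes
    uint32_t val32 = 0;
    for (size_t i = 0; i < remaining; i++) {
        ((uint8_t *) &val32)[i] = data[size - remaining + i];
    }
    printf("packed tail = 0x%08x\n", val32); // 0x00ccbbaa on little-endian
    return 0;
}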
static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_get_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")");
ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
webgpu_context webgpu_ctx = buf_ctx->webgpu_ctx;
wgpu::Device device = webgpu_ctx->device;
size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
size_t final_size = size;
if (size % 4 != 0) {
// If size is not a multiple of 4, we need to round it up to the next multiple of 4
final_size = size + (4 - (size % 4));
}
std::lock_guard<std::mutex> lock(webgpu_ctx->mutex);
if (webgpu_ctx->get_tensor_staging_buf == nullptr ||
webgpu_ctx->get_tensor_staging_buf.GetSize() < final_size) {
// Create a new staging buffer if it doesn't exist or is too small
if (webgpu_ctx->get_tensor_staging_buf) {
webgpu_ctx->get_tensor_staging_buf.Destroy();
}
ggml_webgpu_create_buffer(device, webgpu_ctx->get_tensor_staging_buf, final_size,
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "get_tensor_staging_buf");
}
// Copy the data from the buffer to the staging buffer
wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
encoder.CopyBufferToBuffer(buf_ctx->buffer, total_offset, webgpu_ctx->get_tensor_staging_buf, 0, final_size);
wgpu::CommandBuffer commands = encoder.Finish();
// Submit the command buffer to the queue
webgpu_ctx->queue.Submit(1, &commands);
// Map the staging buffer to read the data
ggml_backend_webgpu_map_buffer(webgpu_ctx, webgpu_ctx->get_tensor_staging_buf, wgpu::MapMode::Read, 0, final_size);
// Must specify size here since the staging buffer might be larger than the tensor size
const void * mapped_range = webgpu_ctx->get_tensor_staging_buf.GetConstMappedRange(0, final_size);
// Copy the data from the mapped range to the output buffer
std::memcpy(data, mapped_range, size);
webgpu_ctx->get_tensor_staging_buf.Unmap();
}
static void ggml_backend_webgpu_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_clear(" << buffer << ", " << (uint32_t) value << ")");
ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
ggml_backend_webgpu_buffer_memset(buf_ctx->webgpu_ctx, buf_ctx->buffer, value, 0, buffer->size);
}
static ggml_backend_buffer_i ggml_backend_webgpu_buffer_interface = {
/* .free_buffer = */ ggml_backend_webgpu_buffer_free_buffer,
/* .get_base = */ ggml_backend_webgpu_buffer_get_base,
/* .init_tensor = */ NULL, // TODO: optional, needed?
/* .memset_tensor = */ ggml_backend_webgpu_buffer_memset_tensor,
/* .set_tensor = */ ggml_backend_webgpu_buffer_set_tensor,
/* .get_tensor = */ ggml_backend_webgpu_buffer_get_tensor,
/* .cpy_tensor = */ NULL, // TODO: optional, implement this
/* .clear = */ ggml_backend_webgpu_buffer_clear,
/* .reset = */ NULL, // TODO: optional, think it coordinates with .init_tensor
};
/* End GGML Backend Buffer Interface */
/* GGML Backend Buffer Type Interface */
static const char * ggml_backend_webgpu_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
return ctx->device_name.c_str();
}
static ggml_backend_buffer_t ggml_backend_webgpu_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_type_alloc_buffer(" << size << ")");
ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
wgpu::Buffer buf;
ggml_webgpu_create_buffer(ctx->webgpu_ctx->device, buf, size,
wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst, "allocated_buffer");
ggml_backend_webgpu_buffer_context * buf_ctx = new ggml_backend_webgpu_buffer_context(ctx->webgpu_ctx, buf);
return ggml_backend_buffer_init(buft, ggml_backend_webgpu_buffer_interface, buf_ctx, size);
}
static size_t ggml_backend_webgpu_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
return ctx->webgpu_ctx->limits.minStorageBufferOffsetAlignment;
}
// maxBufferSize might be larger, but you can't bind more than maxStorageBufferBindingSize to a single binding.
static size_t ggml_backend_webgpu_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
return ctx->webgpu_ctx->limits.maxStorageBufferBindingSize;
}
/* End GGML Backend Buffer Type Interface */
/* GGML Backend Device Interface */
static const char * ggml_backend_webgpu_device_get_name(ggml_backend_dev_t dev) {
ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
return ctx->device_name.c_str();
}
static const char * ggml_backend_webgpu_device_get_description(ggml_backend_dev_t dev) {
ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
return ctx->device_desc.c_str();
}
static void ggml_backend_webgpu_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
// TODO: what do we actually want to return here? maxBufferSize might not be the full available memory.
*free = ctx->webgpu_ctx->limits.maxBufferSize;
*total = ctx->webgpu_ctx->limits.maxBufferSize;
}
static enum ggml_backend_dev_type ggml_backend_webgpu_device_get_type(ggml_backend_dev_t dev) {
GGML_UNUSED(dev);
return GGML_BACKEND_DEVICE_TYPE_GPU;
}
static void ggml_backend_webgpu_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
props->name = ggml_backend_webgpu_device_get_name(dev);
props->description = ggml_backend_webgpu_device_get_description(dev);
props->type = ggml_backend_webgpu_device_get_type(dev);
ggml_backend_webgpu_device_get_memory(dev, &props->memory_free, &props->memory_total);
props->caps = {
/* .async = */ false,
/* .host_buffer = */ false,
/* .buffer_from_host_ptr = */ false,
/* .events = */ false,
};
}
static ggml_guid_t ggml_backend_webgpu_guid(void) {
static const char * guid_str = "__ggml_webgpu :)";
return reinterpret_cast<ggml_guid_t>((void *)guid_str);
}
static void ggml_webgpu_init_memset_pipeline(webgpu_context webgpu_ctx) {
// we use the maximum workgroup size for the memset pipeline
size_t max_wg_size = webgpu_ctx->limits.maxComputeWorkgroupSizeX;
size_t max_threads = max_wg_size * webgpu_ctx->limits.maxComputeWorkgroupsPerDimension;
// Size the bytes_per_thread so that the largest buffer size can be handled
webgpu_ctx->memset_bytes_per_thread = (webgpu_ctx->limits.maxStorageBufferBindingSize + max_threads - 1) / max_threads;
std::vector<wgpu::ConstantEntry> constants(2);
constants[0].key = "wg_size";
constants[0].value = max_wg_size;
constants[1].key = "bytes_per_thread";
constants[1].value = webgpu_ctx->memset_bytes_per_thread;
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->memset_pipeline, wgsl_memset, "memset", constants);
ggml_webgpu_create_buffer(webgpu_ctx->device, webgpu_ctx->memset_params_dev_buf,
3 * sizeof(uint32_t), // 3 parameters: buffer size, offset, value
wgpu::BufferUsage::Uniform | wgpu::BufferUsage::CopyDst, "memset_params_dev_buf");
ggml_webgpu_create_buffer(webgpu_ctx->device, webgpu_ctx->memset_params_host_buf,
3 * sizeof(uint32_t), wgpu::BufferUsage::MapWrite | wgpu::BufferUsage::CopySrc, "memset_params_host_buf");
}
static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context webgpu_ctx) {
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline, wgsl_mul_mat, "mul_mat");
ggml_webgpu_create_buffer(webgpu_ctx->device, webgpu_ctx->mul_mat_params_dev_buf, WEBGPU_MUL_MAT_PARAMS_SIZE,
wgpu::BufferUsage::Uniform | wgpu::BufferUsage::CopyDst, "mul_mat_params_dev_buf");
ggml_webgpu_create_buffer(webgpu_ctx->device, webgpu_ctx->mul_mat_params_host_buf, WEBGPU_MUL_MAT_PARAMS_SIZE,
wgpu::BufferUsage::MapWrite | wgpu::BufferUsage::CopySrc, "mul_mat_params_host_buf");
}
static void ggml_webgpu_init_cpy_pipeline(webgpu_context webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants(1);
constants[0].key = "wg_size";
constants[0].value = webgpu_ctx->limits.maxComputeWorkgroupSizeX;
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline, wgsl_cpy, "cpy", constants);
ggml_webgpu_create_buffer(webgpu_ctx->device, webgpu_ctx->cpy_params_dev_buf, WEBGPU_CPY_PARAMS_SIZE,
wgpu::BufferUsage::Uniform | wgpu::BufferUsage::CopyDst, "cpy_params_dev_buf");
ggml_webgpu_create_buffer(webgpu_ctx->device, webgpu_ctx->cpy_params_host_buf, WEBGPU_CPY_PARAMS_SIZE,
wgpu::BufferUsage::MapWrite | wgpu::BufferUsage::CopySrc, "cpy_params_host_buf");
}
// TODO: Make thread safe if multiple devices are used
static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, const char * params) {
GGML_UNUSED(params);
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_device_init()");
ggml_backend_webgpu_device_context * dev_ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
webgpu_context webgpu_ctx = dev_ctx->webgpu_ctx;
std::lock_guard<std::mutex> lock(webgpu_ctx->mutex);
if (!webgpu_ctx->device_initialized) {
// Initialize device
wgpu::DeviceDescriptor dev_desc;
dev_desc.requiredLimits = &webgpu_ctx->limits;
dev_desc.requiredFeatures = webgpu_ctx->features.features;
dev_desc.requiredFeatureCount = webgpu_ctx->features.featureCount;
dev_desc.SetDeviceLostCallback(wgpu::CallbackMode::AllowSpontaneous,
[](const wgpu::Device& device, wgpu::DeviceLostReason reason, wgpu::StringView message) {
GGML_UNUSED(device);
GGML_LOG_ERROR("ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast<int>(reason), message.data);
});
dev_desc.SetUncapturedErrorCallback(
[](const wgpu::Device& device, wgpu::ErrorType reason, wgpu::StringView message) {
GGML_UNUSED(device);
GGML_LOG_ERROR("ggml_webgpu: Device error! Reason: %d, Message: %s\n", static_cast<int>(reason), message.data);
});
webgpu_ctx->instance.WaitAny(webgpu_ctx->adapter.RequestDevice(&dev_desc, wgpu::CallbackMode::WaitAnyOnly,
[webgpu_ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) {
if (status != wgpu::RequestDeviceStatus::Success) {
GGML_LOG_ERROR("ggml_webgpu: Failed to get a device: %s\n", message.data);
return;
}
webgpu_ctx->device = device;
}),
UINT64_MAX
);
GGML_ASSERT(webgpu_ctx->device != nullptr);
// Initialize (compute) queue
webgpu_ctx->queue = webgpu_ctx->device.GetQueue();
ggml_webgpu_init_memset_pipeline(webgpu_ctx);
ggml_webgpu_init_mul_mat_pipeline(webgpu_ctx);
ggml_webgpu_init_cpy_pipeline(webgpu_ctx);
webgpu_ctx->device_initialized = true;
}
static ggml_backend_webgpu_context backend_ctx;
backend_ctx.name = GGML_WEBGPU_NAME + std::string(": ") + dev_ctx->device_name;
backend_ctx.webgpu_ctx = webgpu_ctx;
// See GGML Backend Interface section
static ggml_backend backend = {
/* .guid = */ ggml_backend_webgpu_guid(),
/* .interface = */ ggml_backend_webgpu_i,
/* .device = */ dev,
/* .context = */ &backend_ctx,
};
return &backend;
}
static ggml_backend_buffer_type_t ggml_backend_webgpu_device_get_buffer_type(ggml_backend_dev_t dev) {
// See GGML Backend Buffer Type Interface section
static struct ggml_backend_buffer_type ggml_backend_webgpu_buffer_type = {
/* .iface = */ {
/* .get_name = */ ggml_backend_webgpu_buffer_type_get_name,
/* .alloc_buffer = */ ggml_backend_webgpu_buffer_type_alloc_buffer,
/* .get_alignment = */ ggml_backend_webgpu_buffer_type_get_alignment,
/* .get_max_size = */ ggml_backend_webgpu_buffer_type_get_max_size,
/* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
/* .is_host = */ NULL, // defaults to false
},
/* .device = */ dev,
/* .context = */ NULL,
};
return &ggml_backend_webgpu_buffer_type;
}
static bool ggml_backend_webgpu_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
GGML_UNUSED(dev);
return buft->iface.get_name == ggml_backend_webgpu_buffer_type_get_name;
}
static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
GGML_UNUSED(dev);
switch (op->op) {
case GGML_OP_NONE:
case GGML_OP_VIEW:
case GGML_OP_PERMUTE:
return true;
case GGML_OP_CPY:
return op->type == GGML_TYPE_F16 && op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_MUL_MAT:
return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
default:
return false;
}
}
static struct ggml_backend_device_i ggml_backend_webgpu_device_i = {
/* .get_name = */ ggml_backend_webgpu_device_get_name,
/* .get_description = */ ggml_backend_webgpu_device_get_description,
/* .get_memory = */ ggml_backend_webgpu_device_get_memory,
/* .get_type = */ ggml_backend_webgpu_device_get_type,
/* .get_props = */ ggml_backend_webgpu_device_get_props,
/* .init_backend = */ ggml_backend_webgpu_device_init,
/* .get_buffer_type = */ ggml_backend_webgpu_device_get_buffer_type,
/* .get_host_buffer_type = */ NULL,
/* .buffer_from_host_ptr = */ NULL,
/* .supports_op = */ ggml_backend_webgpu_device_supports_op,
/* .supports_buft = */ ggml_backend_webgpu_device_supports_buft,
/* .offload_op = */ NULL,
/* .event_new = */ NULL,
/* .event_free = */ NULL,
/* .event_synchronize = */ NULL,
};
/* End GGML Backend Device Interface */
/* GGML Backend Registration Interface */
static const char * ggml_backend_webgpu_reg_get_name(ggml_backend_reg_t reg) {
ggml_backend_webgpu_reg_context * ctx = static_cast<ggml_backend_webgpu_reg_context *>(reg->context);
return ctx->name;
}
static size_t ggml_backend_webgpu_reg_get_device_count(ggml_backend_reg_t reg) {
ggml_backend_webgpu_reg_context * ctx = static_cast<ggml_backend_webgpu_reg_context *>(reg->context);
return ctx->device_count;
}
// TODO: Does this need to be thread safe? Is it only called once?
// Only one device is supported for now
static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t reg, size_t index) {
GGML_ASSERT(index == 0);
WEBGPU_LOG_DEBUG("ggml_backend_reg_get_device()");
ggml_backend_webgpu_reg_context * reg_ctx = static_cast<ggml_backend_webgpu_reg_context *>(reg->context);
webgpu_context ctx = reg_ctx->webgpu_ctx;
wgpu::RequestAdapterOptions options = {};
auto callback = [](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char *message, void *userdata) {
if (status != wgpu::RequestAdapterStatus::Success) {
GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message);
return;
}
*static_cast<wgpu::Adapter *>(userdata) = adapter;
};
void *userdata = &ctx->adapter;
ctx->instance.WaitAny(ctx->instance.RequestAdapter(&options, wgpu::CallbackMode::WaitAnyOnly, callback, userdata), UINT64_MAX);
GGML_ASSERT(ctx->adapter != nullptr);
ctx->adapter.GetLimits(&ctx->limits);
ctx->adapter.GetFeatures(&ctx->features);
wgpu::AdapterInfo info{};
ctx->adapter.GetInfo(&info);
static ggml_backend_webgpu_device_context device_ctx;
device_ctx.webgpu_ctx = ctx;
device_ctx.device_name = GGML_WEBGPU_NAME;
device_ctx.device_desc = std::string(info.description.data);
GGML_LOG_INFO("ggml_webgpu: adapter_info: vendor_id: %u | vendor: %s | architecture: %s | device_id: %u | name: %s | device_desc: %s\n",
info.vendorID, info.vendor.data, info.architecture.data, info.deviceID, info.device.data, info.description.data);
// See GGML Backend Device Interface section
static ggml_backend_device device = {
/* .iface = */ ggml_backend_webgpu_device_i,
/* .reg = */ reg,
/* .context = */ &device_ctx,
};
return &device;
}
static const struct ggml_backend_reg_i ggml_backend_webgpu_reg_i = {
/* .get_name = */ ggml_backend_webgpu_reg_get_name,
/* .get_device_count = */ ggml_backend_webgpu_reg_get_device_count,
/* .get_device = */ ggml_backend_webgpu_reg_get_device,
/* .get_proc_address = */ NULL,
};
/* End GGML Backend Registration Interface */
// TODO: Does this need to be thread safe? Is it only called once?
ggml_backend_reg_t ggml_backend_webgpu_reg() {
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_reg()");
webgpu_context webgpu_ctx = std::make_shared<webgpu_context_struct>();
webgpu_ctx->device_initialized = false;
static ggml_backend_webgpu_reg_context ctx;
ctx.webgpu_ctx = webgpu_ctx;
ctx.name = GGML_WEBGPU_NAME;
ctx.device_count = 1;
wgpu::InstanceDescriptor instance_descriptor{};
std::vector<wgpu::InstanceFeatureName> instance_features = {wgpu::InstanceFeatureName::TimedWaitAny};
instance_descriptor.requiredFeatures = instance_features.data();
instance_descriptor.requiredFeatureCount = instance_features.size();
webgpu_ctx->instance = wgpu::CreateInstance(&instance_descriptor);
GGML_ASSERT(webgpu_ctx->instance != nullptr);
static ggml_backend_reg reg = {
/* .api_version = */ GGML_BACKEND_API_VERSION,
/* .iface = */ ggml_backend_webgpu_reg_i,
/* .context = */ &ctx,
};
return &reg;
}
ggml_backend_t ggml_backend_webgpu_init(void) {
ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_webgpu_reg(), 0);
return ggml_backend_webgpu_device_init(dev, nullptr);
}
GGML_BACKEND_DL_IMPL(ggml_backend_webgpu_reg)
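
A hedged usage sketch of the public entry point defined above; the returned handle is then driven through the generic ggml-backend API (graph construction elided):

// Illustrative only: obtain a WebGPU backend and release it again.
#include "ggml-backend.h"
#include "ggml-webgpu.h"

int main() {
    ggml_backend_t backend = ggml_backend_webgpu_init(); // device 0 via the registration above
    // ... build a ggml graph and run it with ggml_backend_graph_compute(backend, graph) ...
    ggml_backend_free(backend);
    return 0;
}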

View file

@ -0,0 +1,60 @@
enable f16;
@group(0) @binding(0)
var<storage, read_write> src: array<f32>;
@group(0) @binding(1)
var<storage, read_write> dst: array<f16>;
struct Params {
ne: u32, // total number of elements
offset_src: u32, // in elements
offset_dst: u32, // in elements
// Strides (in elements) may be permuted
stride_src0: u32,
stride_src1: u32,
stride_src2: u32,
stride_src3: u32,
stride_dst0: u32,
stride_dst1: u32,
stride_dst2: u32,
stride_dst3: u32,
// Logical shape (same for both tensors)
ne0: u32,
ne1: u32,
ne2: u32,
ne3: u32,
};
@group(0) @binding(2)
var<uniform> params: Params;
override wg_size: u32;
@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
if (gid.x >= params.ne) {
return;
}
var i = gid.x;
let i3 = i / (params.ne2 * params.ne1 * params.ne0);
i = i % (params.ne2 * params.ne1 * params.ne0);
let i2 = i / (params.ne1 * params.ne0);
i = i % (params.ne1 * params.ne0);
let i1 = i / params.ne0;
let i0 = i % params.ne0;
let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 +
i2 * params.stride_src2 + i3 * params.stride_src3;
let dst_idx = i0 * params.stride_dst0 + i1 * params.stride_dst1 +
i2 * params.stride_dst2 + i3 * params.stride_dst3;
dst[params.offset_dst + dst_idx] = f16(src[params.offset_src + src_idx]);
}
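
The shader decomposes the flat invocation id into 4D coordinates using the logical shape, then applies the (possibly permuted) per-tensor element strides. A hedged CPU reference of the same index math, with the f16 cast elided and all parameter values assumed to come from the host-side setup:

// Hedged CPU reference of the cpy.wgsl index math; the real shader also casts
// each element to f16. Parameter values are illustrative.
#include <cstdint>

struct CpyParams {
    uint32_t stride_src[4]; // element strides of src (may be permuted)
    uint32_t stride_dst[4]; // element strides of dst
    uint32_t ne[4];         // logical shape, shared by src and dst
};

static void cpy_ref(const float * src, float * dst, const CpyParams & p) {
    const uint32_t n = p.ne[0] * p.ne[1] * p.ne[2] * p.ne[3];
    for (uint32_t gid = 0; gid < n; ++gid) {
        uint32_t i = gid;
        const uint32_t i3 = i / (p.ne[2] * p.ne[1] * p.ne[0]); i %= p.ne[2] * p.ne[1] * p.ne[0];
        const uint32_t i2 = i / (p.ne[1] * p.ne[0]);           i %= p.ne[1] * p.ne[0];
        const uint32_t i1 = i / p.ne[0];
        const uint32_t i0 = i % p.ne[0];
        const uint32_t src_idx = i0*p.stride_src[0] + i1*p.stride_src[1] + i2*p.stride_src[2] + i3*p.stride_src[3];
        const uint32_t dst_idx = i0*p.stride_dst[0] + i1*p.stride_dst[1] + i2*p.stride_dst[2] + i3*p.stride_dst[3];
        dst[dst_idx] = src[src_idx]; // f16 conversion omitted in this sketch
    }
}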

View file

@ -0,0 +1,35 @@
import os
import argparse
def escape_triple_quotes(wgsl):
# Simple defense in case of embedded """
return wgsl.replace('"""', '\\"""')
def to_cpp_string_literal(varname, content):
return f'const char* wgsl_{varname} = R"({content})";\n'
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--input', required=True)
parser.add_argument('--output', required=True)
args = parser.parse_args()
with open(args.output, 'w', encoding='utf-8') as out:
out.write("// Auto-generated shader embedding \n\n")
for fname in sorted(os.listdir(args.input)):
if not fname.endswith('.wgsl'):
continue
shader_path = os.path.join(args.input, fname)
varname = os.path.splitext(fname)[0]
with open(shader_path, 'r', encoding='utf-8') as f:
content = f.read()
content = escape_triple_quotes(content)
out.write(to_cpp_string_literal(varname, content))
out.write('\n')
if __name__ == '__main__':
main()
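
For each foo.wgsl the script emits one raw string literal named wgsl_foo into the generated header. A hedged sketch of what ggml-wgsl-shaders.hpp ends up looking like for memset.wgsl (content abbreviated) and how the backend consumes it:

// Auto-generated shader embedding (sketch; exact whitespace is up to the script)

const char* wgsl_memset = R"(@group(0) @binding(0)
var<storage, read_write> output_buffer: array<u32>;
/* ... rest of memset.wgsl ... */)";

// The backend compiles the embedded source directly, e.g.:
//   ggml_webgpu_create_pipeline(device, ctx->memset_pipeline, wgsl_memset, "memset", constants);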

View file

@ -0,0 +1,40 @@
@group(0) @binding(0)
var<storage, read_write> output_buffer: array<u32>;
struct Params {
offset: u32, // in bytes
size: u32, // in bytes
value: u32, // four 8-bit values, either all identical (memset_tensor) or distinct (patching unaligned tails from set_tensor)
};
@group(0) @binding(1)
var<uniform> params: Params;
override wg_size: u32;
override bytes_per_thread: u32;
@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let i = gid.x * bytes_per_thread;
let start = params.offset;
let end = params.offset + params.size;
for (var j: u32 = 0u; j < bytes_per_thread; j = j + 1u) {
let byte_index = start + i + j;
if (byte_index + 4u <= end) {
output_buffer[(byte_index >> 2u)] = params.value;
} else {
// Handle tail (unaligned)
for (var k: u32 = 0u; k < 4u; k = k + 1u) {
let idx = byte_index + k;
if (idx < end) {
let word_idx = idx >> 2u;
let byte_offset = (idx & 3u) * 8u;
let mask = ~(0xffu << byte_offset);
let existing = output_buffer[word_idx];
output_buffer[word_idx] = (existing & mask) | ((params.value & 0xffu) << byte_offset);
}
}
}
}
}
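
Full words inside the range are written directly; the unaligned tail is patched byte-by-byte with a mask so that bytes past `end` are never touched. A hedged C++ reproduction of that read-modify-write:

// Hedged reproduction of the unaligned-tail path above: patch a single byte of
// a 32-bit word with (value & 0xff), leaving the other bytes untouched.
#include <cassert>
#include <cstdint>

static uint32_t patch_byte(uint32_t word, uint32_t byte_in_word /*0..3*/, uint32_t value) {
    const uint32_t byte_offset = byte_in_word * 8u;
    const uint32_t mask        = ~(0xffu << byte_offset);
    return (word & mask) | ((value & 0xffu) << byte_offset);
}

int main() {
    uint32_t w = 0xddccbbaa;
    w = patch_byte(w, 2, 0x00); // clear only byte 2
    assert(w == 0xdd00bbaa);    // bytes 0, 1, 3 unchanged
    return 0;
}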

View file

@ -0,0 +1,56 @@
struct MulMatParams {
m: u32,
n: u32,
k: u32,
// all strides are in elements
stride_01: u32,
stride_11: u32,
stride_02: u32,
stride_12: u32,
stride_03: u32,
stride_13: u32,
bs02: u32,
bs03: u32,
broadcast2: u32,
broadcast3: u32
};
@group(0) @binding(0) var<storage, read_write> src0: array<f32>; // N rows, K columns
@group(0) @binding(1) var<storage, read_write> src1: array<f32>; // M rows, K columns (transposed)
@group(0) @binding(2) var<storage, read_write> dst: array<f32>; // M rows, N columns
@group(0) @binding(3) var<uniform> params: MulMatParams;
@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
let total = params.m * params.n * params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3;
if (global_id.x >= total) {
return;
}
let dst2_stride = params.m * params.n;
let dst3_stride = dst2_stride * params.bs02 * params.broadcast2;
let dst3_idx = global_id.x / dst3_stride;
let src03_idx = dst3_idx / params.broadcast3; // src0 may be broadcast along the third dimension
let src13_idx = dst3_idx; // src1 is not broadcast
let dst3_rem = global_id.x % dst3_stride;
let dst2_idx = dst3_rem / dst2_stride;
let src02_idx = dst2_idx / params.broadcast2; // src0 may also be broadcast along the second dimension
let src12_idx = dst2_idx; // src1 is not broadcast
let dst2_rem = dst3_rem % dst2_stride;
let row = dst2_rem / params.n; // output row
let col = dst2_rem % params.n; // output column
var sum = 0.0;
for (var i: u32 = 0u; i < params.k; i = i + 1u) {
let src0_idx = src03_idx * params.stride_03 + src02_idx * params.stride_02 + col * params.stride_01 + i;
let src1_idx = src13_idx * params.stride_13 + src12_idx * params.stride_12 + row * params.stride_11 + i;
sum = sum + src0[src0_idx] * src1[src1_idx];
}
dst[dst3_idx * dst3_stride + dst2_idx * dst2_stride + row * params.n + col] = sum;
}
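
Each invocation maps its flat id to (batch3, batch2, row, col) of dst and divides the batch indices by the broadcast ratios to locate the src0 slice. A hedged CPU reference with the same parameter meanings, handy for cross-checking the shader against a naive matmul:

// Hedged CPU reference of the mul_mat.wgsl indexing (strides in elements).
#include <cstdint>
#include <vector>

struct MulMatParams {
    uint32_t m, n, k;
    uint32_t stride_01, stride_11, stride_02, stride_12, stride_03, stride_13;
    uint32_t bs02, bs03, broadcast2, broadcast3;
};

static void mul_mat_ref(const std::vector<float> & src0, const std::vector<float> & src1,
                        std::vector<float> & dst, const MulMatParams & p) {
    const uint32_t dst2_stride = p.m * p.n;
    const uint32_t dst3_stride = dst2_stride * p.bs02 * p.broadcast2;
    const uint32_t total       = p.m * p.n * p.bs02 * p.broadcast2 * p.bs03 * p.broadcast3;

    for (uint32_t id = 0; id < total; ++id) {
        const uint32_t dst3_idx  = id / dst3_stride;
        const uint32_t src03_idx = dst3_idx / p.broadcast3; // src0 broadcast along dim 3
        const uint32_t src13_idx = dst3_idx;                // src1 not broadcast
        const uint32_t dst3_rem  = id % dst3_stride;
        const uint32_t dst2_idx  = dst3_rem / dst2_stride;
        const uint32_t src02_idx = dst2_idx / p.broadcast2; // src0 broadcast along dim 2
        const uint32_t src12_idx = dst2_idx;
        const uint32_t dst2_rem  = dst3_rem % dst2_stride;
        const uint32_t row       = dst2_rem / p.n;
        const uint32_t col       = dst2_rem % p.n;

        float sum = 0.0f;
        for (uint32_t i = 0; i < p.k; ++i) {
            sum += src0[src03_idx*p.stride_03 + src02_idx*p.stride_02 + col*p.stride_01 + i] *
                   src1[src13_idx*p.stride_13 + src12_idx*p.stride_12 + row*p.stride_11 + i];
        }
        dst[dst3_idx*dst3_stride + dst2_idx*dst2_stride + row*p.n + col] = sum;
    }
}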

View file

@ -367,6 +367,7 @@ class MODEL_ARCH(IntEnum):
HUNYUAN_MOE = auto()
SMOLLM3 = auto()
LFM2 = auto()
DREAM = auto()
class VISION_PROJECTOR_TYPE(IntEnum):
@ -683,6 +684,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.HUNYUAN_MOE: "hunyuan-moe",
MODEL_ARCH.SMOLLM3: "smollm3",
MODEL_ARCH.LFM2: "lfm2",
MODEL_ARCH.DREAM: "dream",
}
VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
@ -1289,6 +1291,21 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.DREAM: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.QWEN2VL: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,

View file

@ -338,6 +338,9 @@ extern "C" {
bool swa_full; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
// NOTE: setting to false when n_seq_max > 1 can cause bad performance in some cases
// ref: https://github.com/ggml-org/llama.cpp/pull/13845#issuecomment-2924800573
bool kv_unified; // use a unified buffer across the input sequences when computing the attention
// try to disable when n_seq_max > 1 for improved performance when the sequences do not share a large prefix
// ref: https://github.com/ggml-org/llama.cpp/pull/14363
};
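
A minimal sketch of enabling the new flag programmatically (it mirrors --kv-unified / -kvu); this assumes the llama_init_from_model() entry point and is not a complete program:

// Hedged sketch: request a single unified KV buffer for all sequences.
#include "llama.h"

llama_context * make_ctx(llama_model * model) {
    llama_context_params cparams = llama_context_default_params();
    cparams.n_seq_max  = 4;
    cparams.kv_unified = true; // one KV buffer shared by all sequences
    return llama_init_from_model(model, cparams);
}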
// model quantization parameters
@ -728,7 +731,7 @@ extern "C" {
// - lazily on next llama_decode()
// p0 < 0 : [0, p1]
// p1 < 0 : [p0, inf)
DEPRECATED(void llama_kv_self_seq_div(
DEPRECATED(LLAMA_API void llama_kv_self_seq_div(
struct llama_context * ctx,
llama_seq_id seq_id,
llama_pos p0,
@ -1008,6 +1011,7 @@ extern "C" {
LLAMA_API llama_token llama_vocab_sep(const struct llama_vocab * vocab); // sentence separator
LLAMA_API llama_token llama_vocab_nl (const struct llama_vocab * vocab); // next-line
LLAMA_API llama_token llama_vocab_pad(const struct llama_vocab * vocab); // padding
LLAMA_API llama_token llama_vocab_mask(const struct llama_vocab * vocab); // mask
LLAMA_API bool llama_vocab_get_add_bos(const struct llama_vocab * vocab);
LLAMA_API bool llama_vocab_get_add_eos(const struct llama_vocab * vocab);

View file

@ -85,6 +85,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_HUNYUAN_MOE, "hunyuan-moe" },
{ LLM_ARCH_SMOLLM3, "smollm3" },
{ LLM_ARCH_LFM2, "lfm2" },
{ LLM_ARCH_DREAM, "dream" },
{ LLM_ARCH_UNKNOWN, "(unknown)" },
};
@ -1891,6 +1892,23 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
},
},
{
LLM_ARCH_DREAM,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
};
static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
@ -2133,3 +2151,12 @@ bool llm_arch_is_hybrid(const llm_arch & arch) {
return false;
}
}
bool llm_arch_is_diffusion(const llm_arch & arch) {
switch (arch) {
case LLM_ARCH_DREAM:
return true;
default:
return false;
}
}

View file

@ -89,6 +89,7 @@ enum llm_arch {
LLM_ARCH_HUNYUAN_MOE,
LLM_ARCH_SMOLLM3,
LLM_ARCH_LFM2,
LLM_ARCH_DREAM,
LLM_ARCH_UNKNOWN,
};
@ -479,3 +480,4 @@ const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor);
bool llm_arch_is_recurrent(const llm_arch & arch);
bool llm_arch_is_hybrid (const llm_arch & arch);
bool llm_arch_is_diffusion(const llm_arch & arch);

View file

@ -27,6 +27,7 @@ bool llama_batch_allocr::init(
const llama_vocab & vocab,
const llama_memory_i * memory,
uint32_t n_embd,
uint32_t n_seq_max,
bool output_all) {
clear();
@ -40,6 +41,11 @@ bool llama_batch_allocr::init(
// validate input batch
//
if (n_seq_max > LLAMA_MAX_SEQ) {
LLAMA_LOG_ERROR("%s: n_seq_max = %d > %d\n", __func__, n_seq_max, LLAMA_MAX_SEQ);
return false;
}
if (batch.token) {
for (int32_t i = 0; i < batch.n_tokens; ++i) {
if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= vocab.n_tokens()) {
@ -52,8 +58,8 @@ bool llama_batch_allocr::init(
if (batch.seq_id) {
for (int32_t i = 0; i < batch.n_tokens; ++i) {
for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= LLAMA_MAX_SEQ)) {
LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], LLAMA_MAX_SEQ);
if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= (llama_seq_id) n_seq_max)) {
LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], (llama_seq_id) n_seq_max);
return false;
}
}
@ -86,7 +92,7 @@ bool llama_batch_allocr::init(
// initialize the starting position for each sequence based on the positions in the memory
llama_pos p0[LLAMA_MAX_SEQ];
for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
for (uint32_t s = 0; s < n_seq_max; ++s) {
if (!memory) {
// if no memory -> start from 0
p0[s] = 0;
@ -143,7 +149,8 @@ bool llama_batch_allocr::init(
// compute stats
//
this->n_embd = n_embd;
this->n_embd = n_embd;
this->n_seq_max = n_seq_max;
// count the outputs in this batch
for (int32_t i = 0; i < batch.n_tokens; ++i) {
@ -189,7 +196,7 @@ bool llama_batch_allocr::init(
seq_set_map[cur].push_back(i);
}
for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
for (uint32_t s = 0; s < n_seq_max; ++s) {
if (seq_set_unq.test(s)) {
seq_idx[s] = seq_id_unq.size();
seq_id_unq.push_back(s);
@ -241,7 +248,7 @@ bool llama_batch_allocr::init(
// consistency checks
//
for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
for (uint32_t s = 0; s < n_seq_max; ++s) {
if (seq_pos[s].empty()) {
continue;
}
@ -283,8 +290,8 @@ bool llama_batch_allocr::init(
}
if (memory) {
for (int32_t s0 = 0; s0 < LLAMA_MAX_SEQ; ++s0) {
for (int32_t s1 = 0; s1 < LLAMA_MAX_SEQ; ++s1) {
for (uint32_t s0 = 0; s0 < n_seq_max; ++s0) {
for (uint32_t s1 = 0; s1 < n_seq_max; ++s1) {
if (seq_cpl[s0][s1]) {
if (memory->seq_pos_min(s0) != memory->seq_pos_min(s1) ||
memory->seq_pos_max(s0) != memory->seq_pos_max(s1)) {
@ -315,12 +322,12 @@ bool llama_batch_allocr::init(
//
{
seq_set_t cur_seq_set[LLAMA_MAX_SEQ];
for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
for (uint32_t s = 0; s < n_seq_max; ++s) {
cur_seq_set[s].set();
}
llama_pos cur_seq_pos[LLAMA_MAX_SEQ];
for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
for (uint32_t s = 0; s < n_seq_max; ++s) {
cur_seq_pos[s] = -1;
}
@ -691,7 +698,7 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
}
}
for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
for (uint32_t s = 0; s < n_seq_max; ++s) {
if (seq_set_unq.test(s)) {
ubatch.seq_idx[s] = ubatch.seq_id_unq.size();
ubatch.seq_id_unq.push_back(s);

View file

@ -48,6 +48,7 @@ public:
const llama_vocab & vocab,
const llama_memory_i * memory,
uint32_t n_embd,
uint32_t n_seq_max,
bool output_all);
const llama_batch & get_batch() const;
@ -100,6 +101,7 @@ private:
const uint32_t n_pos_per_embd;
uint32_t n_embd;
uint32_t n_seq_max;
uint32_t n_outputs;
std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id

View file

@ -98,10 +98,20 @@ llama_context::llama_context(
LLAMA_LOG_WARN("%s: n_batch is less than GGML_KQ_MASK_PAD - increasing to %d\n", __func__, GGML_KQ_MASK_PAD);
cparams.n_batch = GGML_KQ_MASK_PAD;
}
cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
cparams.op_offload = params.op_offload;
cparams.kv_unified = params.kv_unified;
{
const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS");
const bool supports_set_rows = LLAMA_SET_ROWS ? atoi(LLAMA_SET_ROWS) : 0;
if (!supports_set_rows && !cparams.kv_unified) {
LLAMA_LOG_WARN("%s: non-unified KV cache requires ggml_set_rows() - forcing unified KV cache\n", __func__);
cparams.kv_unified = true;
}
}
const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
@ -112,6 +122,7 @@ llama_context::llama_context(
LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch);
LLAMA_LOG_INFO("%s: causal_attn = %d\n", __func__, cparams.causal_attn);
LLAMA_LOG_INFO("%s: flash_attn = %d\n", __func__, cparams.flash_attn);
LLAMA_LOG_INFO("%s: kv_unified = %s\n", __func__, cparams.kv_unified ? "true" : "false");
LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);
@ -267,7 +278,7 @@ llama_context::llama_context(
// reserve worst-case graph
if (!hparams.vocab_only && memory) {
const uint32_t n_seqs = cparams.n_seq_max;
const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
@ -300,7 +311,7 @@ llama_context::llama_context(
// reserve with tg graph to get the number of splits and nodes
{
auto * gf = graph_reserve(1, 1, 1, mctx.get());
auto * gf = graph_reserve(n_seqs, n_seqs, n_seqs, mctx.get());
if (!gf) {
throw std::runtime_error("failed to allocate compute tg buffers");
}
@ -311,6 +322,10 @@ llama_context::llama_context(
// reserve again with pp graph to avoid ggml-alloc reallocations during inference
{
// TODO: not sure if the following graph would be the worst case for multi-stream KV caches:
//
// auto * gf = graph_reserve(n_tokens, 1, n_tokens, mctx.get());
//
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
if (!gf) {
throw std::runtime_error("failed to allocate compute pp buffers");
@ -475,7 +490,7 @@ bool llama_context::kv_self_update(bool optimize) {
throw std::runtime_error("failed to initialize memory context");
}
const uint32_t n_seqs = cparams.n_seq_max;
const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
@ -735,13 +750,15 @@ int llama_context::encode(const llama_batch & batch_inp) {
const int32_t n_vocab = model.vocab.n_tokens();
// note: during encode, we always pass the full sequence starting from pos = 0
if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, true)) {
if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) {
LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
return -1;
}
const uint32_t n_tokens = balloc->get_n_tokens();
// [TAG_NO_CACHE_PAD]
// TODO: add new split mode where we pad the input sequences so that ubatch.equal_seqs == true
const llama_ubatch ubatch = balloc->split_simple(n_tokens);
// micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
@ -910,7 +927,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
// when computing embeddings, all tokens are output
const bool output_all = cparams.embeddings;
if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, output_all)) {
if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, output_all)) {
LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
return -1;
}
@ -2039,7 +2056,7 @@ void llama_context::opt_epoch_iter(
batch.logits [pos_batch] = true;
}
if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd, true)) {
if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) {
LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
return;
}
@ -2198,6 +2215,7 @@ llama_context_params llama_context_default_params() {
/*.no_perf =*/ true,
/*.op_offload =*/ true,
/*.swa_full =*/ true,
/*.kv_unified =*/ false,
};
return result;

View file

@ -11,8 +11,8 @@ struct llama_cparams {
uint32_t n_batch;
uint32_t n_ubatch;
uint32_t n_seq_max;
int n_threads; // number of threads to use for generation
int n_threads_batch; // number of threads to use for batch processing
int32_t n_threads; // number of threads to use for generation
int32_t n_threads_batch; // number of threads to use for batch processing
float rope_freq_base;
float rope_freq_scale;
@ -33,6 +33,7 @@ struct llama_cparams {
bool no_perf;
bool warmup;
bool op_offload;
bool kv_unified;
enum llama_pooling_type pooling_type;

View file

@ -982,13 +982,16 @@ ggml_tensor * llm_graph_context::build_attn_mha(
float kq_scale) const {
const bool v_trans = v->nb[1] > v->nb[2];
// split the batch into streams if needed
const auto n_stream = k->ne[3];
q = ggml_reshape_4d(ctx0, q, q->ne[0], q->ne[1], q->ne[2]/n_stream, n_stream);
q = ggml_permute(ctx0, q, 0, 2, 1, 3);
k = ggml_permute(ctx0, k, 0, 2, 1, 3);
v = ggml_permute(ctx0, v, 0, 2, 1, 3);
const auto n_tokens = q->ne[1];
const auto n_head = q->ne[2];
const auto n_kv = k->ne[1];
const auto n_kv = k->ne[1];
ggml_tensor * cur;
@ -1030,7 +1033,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
#endif
}
cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);
} else {
ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
@ -1075,7 +1078,8 @@ ggml_tensor * llm_graph_context::build_attn_mha(
cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
// recombine streams
cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);
if (!cparams.offload_kqv) {
// all nodes between the KV store and the attention output are run on the CPU
@ -1122,6 +1126,10 @@ ggml_tensor * llm_graph_context::build_attn(
const auto & kq_mask = inp->get_kq_mask();
// [TAG_NO_CACHE_PAD]
// TODO: if ubatch.equal_seqs == true, we can split the three tensors below into ubatch.n_seqs_unq streams
assert(ubatch.equal_seqs == false);
ggml_tensor * q = q_cur;
ggml_tensor * k = k_cur;
ggml_tensor * v = v_cur;
@ -1156,13 +1164,14 @@ static std::unique_ptr<llm_graph_input_attn_kv_unified> build_attn_inp_kv_unifie
{
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
const auto n_kv = mctx_cur->get_n_kv();
const auto n_kv = mctx_cur->get_n_kv();
const auto n_tokens = ubatch.n_tokens;
const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
ggml_set_input(inp->self_kq_mask);
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
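// shape sketch (hypothetical numbers, not part of the change): with n_tokens = 64
// spread over n_stream = 4 sequences, the mask becomes
//   F32 [n_kv, GGML_PAD(64/4, GGML_KQ_MASK_PAD), 1, 4]
// i.e. one padded [n_kv, n_tokens/n_stream] slice per stream, replacing the single
// flat [n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1] mask used before.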
@ -1362,13 +1371,15 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);
const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
{
const auto n_kv = mctx_cur->get_base()->get_n_kv();
inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
ggml_set_input(inp->self_kq_mask);
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@ -1382,7 +1393,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
ggml_set_input(inp->self_kq_mask_swa);
inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;


@ -255,10 +255,10 @@ public:
ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch]
ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch, 1, 1]
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch, 1, 1]
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
const llama_hparams & hparams;
const llama_cparams & cparams;
@ -289,14 +289,14 @@ public:
ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch]
ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
ggml_tensor * self_k_idxs_swa = nullptr; // I64 [n_batch]
ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch]
ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch, 1, 1]
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch, 1, 1]
ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch, 1, 1]
ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch, 1, 1]
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
const llama_hparams & hparams;
const llama_cparams & cparams;


@ -65,6 +65,46 @@ uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const {
return n_embd_head_v * n_head_kv;
}
bool llama_hparams::is_n_embd_k_gqa_variable() const {
const uint32_t val = n_embd_k_gqa();
for (uint32_t il = 0; il < n_layer; ++il) {
if (val != n_embd_k_gqa(il)) {
return true;
}
}
return false;
}
bool llama_hparams::is_n_embd_v_gqa_variable() const {
const uint32_t val = n_embd_v_gqa();
for (uint32_t il = 0; il < n_layer; ++il) {
if (val != n_embd_v_gqa(il)) {
return true;
}
}
return false;
}
uint32_t llama_hparams::n_embd_k_gqa_max() const {
uint32_t val = n_embd_k_gqa();
for (uint32_t il = 0; il < n_layer; ++il) {
val = std::max(val, n_embd_k_gqa(il));
}
return val;
}
uint32_t llama_hparams::n_embd_v_gqa_max() const {
uint32_t val = n_embd_v_gqa();
for (uint32_t il = 0; il < n_layer; ++il) {
val = std::max(val, n_embd_v_gqa(il));
}
return val;
}
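// usage sketch (assumption, not shown in this hunk): code that sizes a single buffer
// shared by all layers can use the maxima, e.g.
//   const uint32_t n_embd_k_max = hparams.n_embd_k_gqa_max();
//   const uint32_t n_embd_v_max = hparams.n_embd_v_gqa_max();
// while is_n_embd_k_gqa_variable()/is_n_embd_v_gqa_variable() let callers detect
// models whose per-layer K/V widths differ and fall back to per-layer handling.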
uint32_t llama_hparams::n_embd_r() const {
if (wkv_head_size != 0) {
// for RWKV models


@ -191,6 +191,14 @@ struct llama_hparams {
// dimension of value embeddings across all k-v heads
uint32_t n_embd_v_gqa(uint32_t il = 0) const;
// true if any layer has a different n_embd_k_gqa/n_embd_v_gqa
bool is_n_embd_k_gqa_variable() const;
bool is_n_embd_v_gqa_variable() const;
// return the maximum n_embd_k_gqa/n_embd_v_gqa across all layers
uint32_t n_embd_k_gqa_max() const;
uint32_t n_embd_v_gqa_max() const;
// dimension of the rolling state embeddings
// corresponds to Mamba's conv_states size or RWKV's token_shift states size
uint32_t n_embd_r() const;


@ -18,16 +18,17 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
bool v_trans,
bool offload,
bool swa_full,
bool unified,
uint32_t kv_size,
uint32_t n_seq_max,
uint32_t n_ubatch,
uint32_t n_pad) : hparams(model.hparams) {
uint32_t n_pad) : hparams(model.hparams), unified(unified) {
llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
llama_kv_cache_unified::layer_filter_cb filter_swa = [&](int32_t il) { return model.hparams.is_swa(il); };
const uint32_t size_base = kv_size;
uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_ubatch, n_pad));
uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*(unified ? n_seq_max : 1) + n_ubatch, n_pad));
// kcpp: pad the SWA KV cache as well, similar to extra_context_handle_fragmentation
size_swa += 32;
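// worked example (hypothetical numbers): n_swa = 4096, n_ubatch = 512, n_pad = 256,
// n_seq_max = 4; GGML_PAD rounds up to a multiple of n_pad:
//   unified:     size_swa = min(size_base, GGML_PAD(4096*4 + 512, 256)) = min(size_base, 16896)
//   per-stream:  size_swa = min(size_base, GGML_PAD(4096*1 + 512, 256)) = min(size_base, 4608)
// the kcpp padding of 32 cells is then added on top of either value.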
@ -45,14 +46,14 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
kv_base = std::make_unique<llama_kv_cache_unified>(
model, std::move(filter_base), type_k, type_v,
v_trans, offload, size_base, n_seq_max, n_pad,
v_trans, offload, unified, size_base, n_seq_max, n_pad,
0, LLAMA_SWA_TYPE_NONE);
LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);
kv_swa = std::make_unique<llama_kv_cache_unified>(
model, std::move(filter_swa), type_k, type_v,
v_trans, offload, size_swa, n_seq_max, n_pad,
v_trans, offload, unified, size_swa, n_seq_max, n_pad,
hparams.n_swa, hparams.swa_type);
}
@ -104,6 +105,11 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
// first try simple split
do {
if (!unified) {
// requires equal splits, so we skip the simple split
break;
}
balloc.split_reset();
std::vector<llama_ubatch> ubatches;
@ -144,7 +150,7 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
std::vector<llama_ubatch> ubatches;
while (true) {
auto ubatch = balloc.split_equal(n_ubatch, false);
auto ubatch = balloc.split_equal(n_ubatch, !unified);
if (ubatch.n_tokens == 0) {
break;


@ -20,6 +20,7 @@ public:
bool v_trans,
bool offload,
bool swa_full,
bool unified,
uint32_t kv_size,
uint32_t n_seq_max,
uint32_t n_ubatch,
@ -68,6 +69,8 @@ public:
private:
const llama_hparams & hparams;
const bool unified;
std::unique_ptr<llama_kv_cache_unified> kv_base;
std::unique_ptr<llama_kv_cache_unified> kv_swa;
};

File diff suppressed because it is too large


@ -35,16 +35,50 @@ public:
std::vector<uint32_t> ids;
};
struct stream_copy_info {
bool empty() const {
assert(ssrc.size() == sdst.size());
return ssrc.empty();
}
std::vector<uint32_t> ssrc;
std::vector<uint32_t> sdst;
};
// for each ubatch, create a slot_info that contains information about where the ubatch should be inserted in the
// KV cells. for example, cell indices for each token, such that: token[i] -> goes to cells[idxs[i]]
struct slot_info {
// data for ggml_set_rows
using idx_vec_t = std::vector<uint32_t>;
idx_vec_t idxs;
// number of streams: ns = s1 - s0 + 1
llama_seq_id s0;
llama_seq_id s1;
std::vector<llama_seq_id> strm; // [ns]
std::vector<idx_vec_t> idxs; // [ns]
uint32_t head() const {
return idxs.at(0);
GGML_ASSERT(idxs.size() == 1);
GGML_ASSERT(!idxs[0].empty());
return idxs[0][0];
}
void resize(size_t n) {
strm.resize(n);
idxs.resize(n);
}
size_t size() const {
GGML_ASSERT(idxs.size() == strm.size());
GGML_ASSERT(!idxs.empty());
return idxs[0].size();
}
size_t n_stream() const {
return strm.size();
}
bool empty() const {
@ -54,9 +88,6 @@ public:
void clear() {
idxs.clear();
}
// TODO: implement
//std::vector<idx_vec_t> seq_idxs;
};
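// layout sketch (hypothetical values): a ubatch that places three tokens in each of
// streams 2 and 3 could be described as
//   slot_info si;
//   si.s0 = 2; si.s1 = 3;
//   si.strm = { 2, 3 };
//   si.idxs = { { 10, 11, 12 },   // cells used in stream 2
//               {  4,  5,  6 } }; // cells used in stream 3
// giving si.n_stream() == 2 and si.size() == 3; head() is only valid for the
// single-stream case.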
using slot_info_vec_t = std::vector<slot_info>;
@ -68,6 +99,7 @@ public:
ggml_type type_v,
bool v_trans,
bool offload,
bool unified,
uint32_t kv_size,
uint32_t n_seq_max,
uint32_t n_pad,
@ -111,7 +143,8 @@ public:
// llama_kv_cache_unified specific API
//
uint32_t get_size() const;
uint32_t get_size() const;
uint32_t get_n_stream() const;
bool get_has_shift() const;
@ -122,8 +155,8 @@ public:
uint32_t get_n_kv() const;
// get views of the current state of the cache
ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv) const;
ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const;
ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;
ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;
// store k_cur and v_cur in the cache based on the provided head location
ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const;
@ -137,7 +170,7 @@ public:
// return empty vector on failure
slot_info_vec_t prepare(const std::vector<llama_ubatch> & ubatches);
bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo);
bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo, const stream_copy_info & sc_info);
// find a slot of kv cells that can hold the ubatch
// if cont == true, then the slot must be continuous
@ -157,8 +190,9 @@ public:
void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
void set_input_k_shift(ggml_tensor * dst) const;
void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
void set_input_k_shift (ggml_tensor * dst) const;
void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
private:
@ -172,15 +206,15 @@ private:
ggml_tensor * k;
ggml_tensor * v;
std::vector<ggml_tensor *> k_stream;
std::vector<ggml_tensor *> v_stream;
};
bool v_trans = true; // the value tensor is transposed
// the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
// note: this is not part of the KV state and it's only used to speed-up the find_slot() method
uint32_t head = 0;
const uint32_t n_seq_max = 1;
const uint32_t n_stream = 1;
// required padding
const uint32_t n_pad = 1;
@ -200,7 +234,17 @@ private:
std::vector<ggml_context_ptr> ctxs;
std::vector<ggml_backend_buffer_ptr> bufs;
llama_kv_cells_unified cells;
// the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
// note: this is not part of the KV state and it's only used to speed-up the find_slot() method
std::vector<uint32_t> v_heads;
std::vector<llama_kv_cells_unified> v_cells;
// maps from a sequence id to a stream id
std::vector<uint32_t> seq_to_stream;
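// mapping sketch (hypothetical): with 4 sequences, a split (non-unified) cache
// typically uses the identity mapping seq_to_stream = {0, 1, 2, 3}, while a unified
// cache maps every sequence to stream 0, i.e. seq_to_stream = {0, 0, 0, 0}.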
// pending stream copies that will be applied during the next update
stream_copy_info sc_info;
std::vector<kv_layer> layers;
@ -237,18 +281,25 @@ private:
ggml_cgraph * gf,
const defrag_info & dinfo) const;
void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
struct cell_ranges_t {
uint32_t strm;
std::vector<std::pair<uint32_t, uint32_t>> data; // ranges, from inclusive, to exclusive
};
void state_write_meta(llama_io_write_i & io, const cell_ranges_t & cr, llama_seq_id seq_id = -1) const;
void state_write_data(llama_io_write_i & io, const cell_ranges_t & cr) const;
bool state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
bool state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count);
};
class llama_kv_cache_unified_context : public llama_memory_context_i {
public:
// some shorthands
using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
using defrag_info = llama_kv_cache_unified::defrag_info;
using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
using defrag_info = llama_kv_cache_unified::defrag_info;
using stream_copy_info = llama_kv_cache_unified::stream_copy_info;
// used for errors
llama_kv_cache_unified_context(llama_memory_status status);
@ -262,7 +313,8 @@ public:
llama_kv_cache_unified * kv,
llama_context * lctx,
bool do_shift,
defrag_info dinfo);
defrag_info dinfo,
stream_copy_info sc_info);
// used to create a batch processing context from a batch
llama_kv_cache_unified_context(
@ -320,6 +372,8 @@ private:
defrag_info dinfo;
stream_copy_info sc_info;
//
// batch processing context
//


@ -40,6 +40,7 @@ llama_memory_hybrid::llama_memory_hybrid(
offload,
kv_size,
n_seq_max,
1,
n_pad,
n_swa,
swa_type


@ -854,6 +854,21 @@ void llama_model::load_hparams(llama_model_loader & ml) {
default: type = LLM_TYPE_UNKNOWN;
}
} break;
case LLM_ARCH_DREAM:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
// Dream models are primarily 7B with 28 layers
switch (hparams.n_layer) {
case 28:
type = LLM_TYPE_7B;
break;
default:
type = LLM_TYPE_UNKNOWN;
}
// Set non-causal attention for diffusion models
hparams.causal_attn = false;
}
break;
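// note (not part of the upstream hunk): diffusion LMs such as Dream predict all
// masked positions of the sequence in parallel at every denoising step, so the
// attention must see the full (bidirectional) context rather than a causal prefix,
// which is why causal_attn is disabled here.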
case LLM_ARCH_QWEN2MOE:
{
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false);
@ -2766,12 +2781,14 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
} break;
case LLM_ARCH_QWEN2:
case LLM_ARCH_QWEN2VL:
case LLM_ARCH_DREAM:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
// output
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), {n_vocab}, TENSOR_NOT_REQUIRED);
// if output is NULL, init from the input tok embed
if (output == NULL) {
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
@ -7849,6 +7866,113 @@ struct llm_build_qwen2 : public llm_graph_context {
// lm_head
cur = build_lora_mm(model.output, cur);
if (model.output_b != nullptr) {
cur = ggml_add(ctx0, cur, model.output_b);
}
cb(cur, "result_output", -1);
res->t_logits = cur;
ggml_build_forward_expand(gf, cur);
}
};
struct llm_build_dream : public llm_graph_context {
llm_build_dream(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) :
llm_graph_context(params) {
// copied from qwen2
const int64_t n_embd_head = hparams.n_embd_head_v;
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_rot);
ggml_tensor * cur;
ggml_tensor * inpL;
inpL = build_inp_embd(model.tok_embd);
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn = build_attn_inp_no_cache();
ggml_tensor * inp_out_ids = build_inp_out_ids();
for (int il = 0; il < n_layer; ++il) {
ggml_tensor * inpSA = inpL;
// norm
cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
cb(cur, "attn_norm", il);
// self-attention
{
// compute Q and K and RoPE them
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
cb(Qcur, "Qcur", il);
ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
cb(Kcur, "Kcur", il);
ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
cb(Vcur, "Vcur", il);
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
cb(Vcur, "Vcur", il);
cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr,
nullptr, 1.0f / sqrtf(float(n_embd_head)), il);
}
if (il == n_layer - 1 && inp_out_ids) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
// feed-forward network
cur = build_norm(ffn_inp, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il);
cb(cur, "ffn_norm", il);
cur = build_ffn(cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL,
model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, il);
cb(cur, "ffn_out", il);
cur = ggml_add(ctx0, cur, ffn_inp);
cur = build_cvec(cur, il);
cb(cur, "l_out", il);
// input for next layer
inpL = cur;
}
cur = inpL;
cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);
cb(cur, "result_norm", -1);
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cb(cur, "result_output", -1);
res->t_logits = cur;
@ -15863,6 +15987,7 @@ private:
cb(zx, "mamba_in_proj", il);
// {8192, 5, 1, 1} -> {8192, 1, 5, 1}
zx = ggml_permute(ctx0, zx, 0, 2, 1, 3);
zx = ggml_cont(ctx0, zx);
zx = ggml_reshape_4d(ctx0, zx, head_dim * 2, n_heads, n_seq_tokens, n_seqs);
cb(zx, "mamba_in_proj_out", il);
@ -15880,7 +16005,6 @@ private:
// conv1d
{
// => {d_conv - 1 + n_seq_tokens, d_inner, n_seqs}
x = ggml_view_2d(ctx0, x, d_inner, n_seq_tokens * n_seqs, d_inner * x->nb[0], 0);
ggml_tensor * conv_x = ggml_concat(ctx0, conv, ggml_transpose(ctx0, x), 0);
cb(conv_x, "mamba_conv1d_input", il);
@ -16587,6 +16711,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
case LLM_ARCH_NOMIC_BERT_MOE:
case LLM_ARCH_NEO_BERT:
case LLM_ARCH_WAVTOKENIZER_DEC:
case LLM_ARCH_DREAM:
{
res = nullptr;
} break;
@ -16627,7 +16752,18 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
} else {
const auto padding = llama_kv_cache_unified::get_padding(cparams);
cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
uint32_t n_ctx_per_stream = cparams.n_ctx;
if (!cparams.kv_unified) {
n_ctx_per_stream = (cparams.n_ctx + cparams.n_seq_max - 1)/cparams.n_seq_max;
n_ctx_per_stream = GGML_PAD(n_ctx_per_stream, padding);
cparams.n_ctx = n_ctx_per_stream*cparams.n_seq_max;
} else {
n_ctx_per_stream = GGML_PAD(n_ctx_per_stream, padding);
cparams.n_ctx = n_ctx_per_stream;
}
LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
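// worked example (hypothetical numbers): n_ctx = 10000, n_seq_max = 4, padding = 256:
//   split KV:   n_ctx_per_stream = GGML_PAD((10000 + 3)/4, 256) = GGML_PAD(2500, 256) = 2560
//               cparams.n_ctx    = 2560*4 = 10240
//   unified KV: n_ctx_per_stream = GGML_PAD(10000, 256) = 10240 = cparams.n_ctx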
@ -16641,7 +16777,8 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
!cparams.flash_attn,
cparams.offload_kqv,
params.swa_full,
cparams.n_ctx,
cparams.kv_unified,
n_ctx_per_stream,
cparams.n_seq_max,
cparams.n_ubatch,
padding);
@ -16655,7 +16792,8 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
params.type_v,
!cparams.flash_attn,
cparams.offload_kqv,
cparams.n_ctx,
cparams.kv_unified,
n_ctx_per_stream,
cparams.n_seq_max,
padding,
hparams.n_swa,
@ -16738,6 +16876,11 @@ llm_graph_result_ptr llama_model::build_graph(
{
llm = std::make_unique<llm_build_qwen2>(*this, params, gf);
} break;
case LLM_ARCH_DREAM:
{
llm = std::make_unique<llm_build_dream>(*this, params, gf);
}
break;
case LLM_ARCH_QWEN2VL:
{
llm = std::make_unique<llm_build_qwen2vl>(*this, params, gf);
@ -17155,6 +17298,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
case LLM_ARCH_BITNET:
case LLM_ARCH_QWEN:
case LLM_ARCH_QWEN2:
case LLM_ARCH_DREAM:
case LLM_ARCH_QWEN2MOE:
case LLM_ARCH_QWEN3:
case LLM_ARCH_QWEN3MOE:


@ -3635,6 +3635,10 @@ llama_token llama_vocab::token_fim_sep() const {
return pimpl->special_fim_sep_id;
}
llama_token llama_vocab::token_mask() const {
return pimpl->special_mask_id;
}
bool llama_vocab::get_add_space_prefix() const {
return pimpl->add_space_prefix;
}
@ -3882,6 +3886,10 @@ llama_token llama_vocab_fim_sep(const struct llama_vocab * vocab) {
return vocab->token_fim_sep();
}
llama_token llama_vocab_mask(const struct llama_vocab * vocab) {
return vocab->token_mask();
}
// deprecated
const char * llama_token_get_text(const struct llama_vocab * vocab, llama_token token) {
return llama_vocab_get_text(vocab, token);


@ -103,6 +103,7 @@ struct llama_vocab {
llama_token token_sep() const;
llama_token token_nl () const;
llama_token token_pad() const;
llama_token token_mask() const;
llama_token token_prefix() const;
llama_token token_middle() const;


@ -127,7 +127,6 @@ struct slot_params {
std::vector<std::string> response_fields;
bool timings_per_token = false;
bool post_sampling_probs = false;
bool ignore_eos = false;
struct common_params_sampling sampling;
struct common_params_speculative speculative;
@ -441,7 +440,6 @@ struct server_task {
{
params.sampling.logit_bias.clear();
params.ignore_eos = json_value(data, "ignore_eos", false);
const auto & logit_bias = data.find("logit_bias");
if (logit_bias != data.end() && logit_bias->is_array()) {
@ -472,6 +470,13 @@ struct server_task {
}
}
}
params.sampling.ignore_eos = json_value(data, "ignore_eos", params_base.sampling.ignore_eos);
if (params.sampling.ignore_eos) {
params.sampling.logit_bias.insert(
params.sampling.logit_bias.end(),
defaults.sampling.logit_bias_eog.begin(), defaults.sampling.logit_bias_eog.end());
}
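// request sketch (hypothetical payload): POST /completion with
//   { "prompt": "Hello", "ignore_eos": true }
// now falls back to params_base.sampling.ignore_eos when the field is absent and
// appends the precomputed (eog_token, -INFINITY) pairs from
// defaults.sampling.logit_bias_eog instead of rebuilding the EOG bias list per request.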
}
{
@ -1898,7 +1903,6 @@ struct server_context {
bool clean_kv_cache = true;
bool add_bos_token = true;
bool has_eos_token = false;
int32_t n_ctx; // total context for all clients / slots
@ -1957,7 +1961,6 @@ struct server_context {
n_ctx = llama_n_ctx(ctx);
add_bos_token = llama_vocab_get_add_bos(vocab);
has_eos_token = llama_vocab_eos(vocab) != LLAMA_TOKEN_NULL;
if (!params_base.speculative.model.path.empty() || !params_base.speculative.model.hf_repo.empty()) {
SRV_INF("loading draft model '%s'\n", params_base.speculative.model.path.c_str());
@ -2217,10 +2220,6 @@ struct server_context {
slot.params.n_predict = slot.n_predict;
}
if (slot.params.ignore_eos && has_eos_token) {
slot.params.sampling.logit_bias.push_back({llama_vocab_eos(vocab), -INFINITY});
}
{
if (slot.smpl != nullptr) {
common_sampler_free(slot.smpl);