Merge branch 'upstream' into concedo_experimental

# Conflicts:
#	.github/workflows/winget.yml
#	CODEOWNERS
#	common/CMakeLists.txt
#	common/arg.cpp
#	docs/ops/SYCL.csv
#	examples/lookup/lookup-create.cpp
#	examples/lookup/lookup-stats.cpp
#	examples/lookup/lookup.cpp
#	examples/speculative-simple/speculative-simple.cpp
#	examples/speculative/speculative.cpp
#	ggml/src/ggml-hip/CMakeLists.txt
#	ggml/src/ggml-sycl/dpct/helper.hpp
#	ggml/src/ggml-sycl/ggml-sycl.cpp
#	ggml/src/ggml-sycl/norm.cpp
#	ggml/src/ggml-zendnn/ggml-zendnn.cpp
#	tests/test-chat-template.cpp
Concedo 2026-01-29 23:05:05 +08:00
commit 7e755014b2
26 changed files with 2773 additions and 925 deletions

Makefile

@@ -716,25 +716,25 @@ clean:
rm -vrf llguidance
# useful tools
main: tools/completion/completion.cpp common/arg.cpp common/chat.cpp common/preset.cpp common/download.cpp build-info.h ggml.o ggml-cpu.o ggml-ops.o ggml-vec.o ggml-binops.o ggml-unops.o llama.o console.o llavaclip_default.o llava.o ggml-backend_default.o ggml-backend-reg_default.o ggml-repack.o $(OBJS_FULL) $(OBJS)
main: tools/completion/completion.cpp common/arg.cpp common/speculative.cpp common/ngram-cache.cpp common/ngram-map.cpp common/chat.cpp common/preset.cpp common/download.cpp build-info.h ggml.o ggml-cpu.o ggml-ops.o ggml-vec.o ggml-binops.o ggml-unops.o llama.o console.o llavaclip_default.o llava.o ggml-backend_default.o ggml-backend-reg_default.o ggml-repack.o $(OBJS_FULL) $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
mainvk: tools/completion/completion.cpp common/arg.cpp common/chat.cpp common/preset.cpp common/download.cpp build-info.h ggml_v4_vulkan.o ggml-cpu.o ggml-ops.o ggml-vec.o ggml-binops.o ggml-unops.o llama.o console.o llavaclip_vulkan.o llava.o ggml-backend_vulkan.o ggml-backend-reg_vulkan.o ggml-vulkan.o ggml-vulkan-shaders.o ggml-repack.o $(OBJS_FULL) $(OBJS) lib/vulkan-1.lib
mainvk: tools/completion/completion.cpp common/arg.cpp common/speculative.cpp common/ngram-cache.cpp common/ngram-map.cpp common/chat.cpp common/preset.cpp common/download.cpp build-info.h ggml_v4_vulkan.o ggml-cpu.o ggml-ops.o ggml-vec.o ggml-binops.o ggml-unops.o llama.o console.o llavaclip_vulkan.o llava.o ggml-backend_vulkan.o ggml-backend-reg_vulkan.o ggml-vulkan.o ggml-vulkan-shaders.o ggml-repack.o $(OBJS_FULL) $(OBJS) lib/vulkan-1.lib
$(CXX) $(CXXFLAGS) -DGGML_USE_VULKAN -DSD_USE_VULKAN $(filter-out %.h,$^) -o $@ $(LDFLAGS)
fitparams: tools/fit-params/fit-params.cpp common/arg.cpp common/chat.cpp common/preset.cpp common/download.cpp build-info.h ggml_v4_vulkan.o ggml-cpu.o ggml-ops.o ggml-vec.o ggml-binops.o ggml-unops.o llama.o console.o llavaclip_vulkan.o llava.o ggml-backend_vulkan.o ggml-backend-reg_vulkan.o ggml-vulkan.o ggml-vulkan-shaders.o ggml-repack.o $(OBJS_FULL) $(OBJS) lib/vulkan-1.lib
fitparams: tools/fit-params/fit-params.cpp common/arg.cpp common/speculative.cpp common/ngram-cache.cpp common/ngram-map.cpp common/chat.cpp common/preset.cpp common/download.cpp build-info.h ggml_v4_vulkan.o ggml-cpu.o ggml-ops.o ggml-vec.o ggml-binops.o ggml-unops.o llama.o console.o llavaclip_vulkan.o llava.o ggml-backend_vulkan.o ggml-backend-reg_vulkan.o ggml-vulkan.o ggml-vulkan-shaders.o ggml-repack.o $(OBJS_FULL) $(OBJS) lib/vulkan-1.lib
$(CXX) $(CXXFLAGS) -DGGML_USE_VULKAN -DSD_USE_VULKAN $(filter-out %.h,$^) -o $@ $(LDFLAGS)
sdmain: otherarch/sdcpp/util.cpp otherarch/sdcpp/main.cpp otherarch/sdcpp/stable-diffusion.cpp otherarch/sdcpp/upscaler.cpp otherarch/sdcpp/model.cpp otherarch/sdcpp/name_conversion.cpp otherarch/sdcpp/tokenize_util.cpp otherarch/sdcpp/version.cpp otherarch/sdcpp/thirdparty/zip.c build-info.h ggml.o ggml-cpu.o ggml-ops.o ggml-vec.o ggml-binops.o ggml-unops.o llama.o console.o llavaclip_default.o llava.o ggml-backend_default.o ggml-backend-reg_default.o ggml-repack.o $(OBJS_FULL) $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
whispermain: otherarch/whispercpp/main.cpp otherarch/whispercpp/whisper.cpp build-info.h ggml.o ggml-cpu.o ggml-ops.o ggml-vec.o ggml-binops.o ggml-unops.o llama.o console.o llavaclip_default.o llava.o ggml-backend_default.o ggml-backend-reg_default.o ggml-repack.o $(OBJS_FULL) $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
ttsmain: tools/tts/tts.cpp common/arg.cpp common/chat.cpp common/preset.cpp common/download.cpp build-info.h ggml.o ggml-cpu.o ggml-ops.o ggml-vec.o ggml-binops.o ggml-unops.o llama.o console.o llavaclip_default.o llava.o ggml-backend_default.o ggml-backend-reg_default.o ggml-repack.o $(OBJS_FULL) $(OBJS)
ttsmain: tools/tts/tts.cpp common/arg.cpp common/speculative.cpp common/ngram-cache.cpp common/ngram-map.cpp common/chat.cpp common/preset.cpp common/download.cpp build-info.h ggml.o ggml-cpu.o ggml-ops.o ggml-vec.o ggml-binops.o ggml-unops.o llama.o console.o llavaclip_default.o llava.o ggml-backend_default.o ggml-backend-reg_default.o ggml-repack.o $(OBJS_FULL) $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
gguf-split: tools/gguf-split/gguf-split.cpp ggml.o ggml-cpu.o ggml-ops.o ggml-vec.o ggml-binops.o ggml-unops.o llama.o build-info.h llavaclip_default.o llava.o ggml-backend_default.o ggml-backend-reg_default.o ggml-repack.o $(OBJS_FULL) $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
mtmd-cli: tools/mtmd/mtmd-cli.cpp tools/mtmd/mtmd.cpp tools/mtmd/mtmd-helper.cpp tools/mtmd/clip.cpp common/debug.cpp common/arg.cpp common/chat.cpp common/preset.cpp common/download.cpp build-info.h ggml.o ggml-cpu.o ggml-ops.o ggml-vec.o ggml-binops.o ggml-unops.o llama.o console.o ggml-backend_default.o ggml-backend-reg_default.o ggml-repack.o $(OBJS_FULL) $(OBJS)
mtmd-cli: tools/mtmd/mtmd-cli.cpp tools/mtmd/mtmd.cpp tools/mtmd/mtmd-helper.cpp tools/mtmd/clip.cpp common/debug.cpp common/arg.cpp common/speculative.cpp common/ngram-cache.cpp common/ngram-map.cpp common/chat.cpp common/preset.cpp common/download.cpp build-info.h ggml.o ggml-cpu.o ggml-ops.o ggml-vec.o ggml-binops.o ggml-unops.o llama.o console.o ggml-backend_default.o ggml-backend-reg_default.o ggml-repack.o $(OBJS_FULL) $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
embedding: examples/embedding/embedding.cpp common/arg.cpp common/chat.cpp common/preset.cpp common/download.cpp src/llama-cparams.cpp build-info.h ggml.o ggml-cpu.o ggml-ops.o ggml-vec.o ggml-binops.o ggml-unops.o llama.o console.o llavaclip_default.o llava.o ggml-backend_default.o ggml-backend-reg_default.o ggml-repack.o $(OBJS_FULL) $(OBJS)
embedding: examples/embedding/embedding.cpp common/arg.cpp common/speculative.cpp common/ngram-cache.cpp common/ngram-map.cpp common/chat.cpp common/preset.cpp common/download.cpp src/llama-cparams.cpp build-info.h ggml.o ggml-cpu.o ggml-ops.o ggml-vec.o ggml-binops.o ggml-unops.o llama.o console.o llavaclip_default.o llava.o ggml-backend_default.o ggml-backend-reg_default.o ggml-repack.o $(OBJS_FULL) $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
embeddingvk: examples/embedding/embedding.cpp common/arg.cpp common/chat.cpp common/preset.cpp common/download.cpp src/llama-cparams.cpp build-info.h ggml_v4_vulkan.o ggml-cpu.o ggml-ops.o ggml-vec.o ggml-binops.o ggml-unops.o llama.o console.o llavaclip_vulkan.o llava.o ggml-backend_vulkan.o ggml-backend-reg_vulkan.o ggml-vulkan.o ggml-vulkan-shaders.o ggml-repack.o $(OBJS_FULL) $(OBJS) lib/vulkan-1.lib
embeddingvk: examples/embedding/embedding.cpp common/arg.cpp common/speculative.cpp common/ngram-cache.cpp common/ngram-map.cpp common/chat.cpp common/preset.cpp common/download.cpp src/llama-cparams.cpp build-info.h ggml_v4_vulkan.o ggml-cpu.o ggml-ops.o ggml-vec.o ggml-binops.o ggml-unops.o llama.o console.o llavaclip_vulkan.o llava.o ggml-backend_vulkan.o ggml-backend-reg_vulkan.o ggml-vulkan.o ggml-vulkan-shaders.o ggml-repack.o $(OBJS_FULL) $(OBJS) lib/vulkan-1.lib
$(CXX) $(CXXFLAGS) -DGGML_USE_VULKAN -DSD_USE_VULKAN $(filter-out %.h,$^) -o $@ $(LDFLAGS)
ttscppmain: otherarch/ttscpp/cli/cli.cpp otherarch/ttscpp/cli/playback.cpp otherarch/ttscpp/cli/playback.h otherarch/ttscpp/cli/write_file.cpp otherarch/ttscpp/cli/write_file.h otherarch/ttscpp/cli/vad.cpp otherarch/ttscpp/cli/vad.h otherarch/ttscpp/src/ttscpp.cpp otherarch/ttscpp/src/ttstokenizer.cpp otherarch/ttscpp/src/ttssampler.cpp otherarch/ttscpp/src/parler_model.cpp otherarch/ttscpp/src/dac_model.cpp otherarch/ttscpp/src/ttsutil.cpp otherarch/ttscpp/src/ttsargs.cpp otherarch/ttscpp/src/ttst5_encoder_model.cpp otherarch/ttscpp/src/phonemizer.cpp otherarch/ttscpp/src/tts_model.cpp otherarch/ttscpp/src/kokoro_model.cpp otherarch/ttscpp/src/dia_model.cpp otherarch/ttscpp/src/orpheus_model.cpp otherarch/ttscpp/src/snac_model.cpp otherarch/ttscpp/src/general_neural_audio_codec.cpp ggml.o ggml-cpu.o ggml-ops.o ggml-vec.o ggml-binops.o ggml-unops.o llama.o console.o llavaclip_default.o llava.o ggml-backend_default.o ggml-backend-reg_default.o ggml-repack.o $(OBJS_FULL) $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)

common/arg.cpp

@@ -6,6 +6,7 @@
#include "json-schema-to-grammar.h"
#include "log.h"
#include "sampling.h"
#include "speculative.h"
#include "chat.h"
#include "build-info.h"
#include "preset.h"
@@ -582,14 +583,14 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
params.mmproj = res.mmproj;
}
// only download mmproj if the current example is using it
for (auto & ex : mmproj_examples) {
for (const auto & ex : mmproj_examples) {
if (ctx_arg.ex == ex) {
common_params_handle_model(params.mmproj, params.hf_token, params.offline);
break;
}
}
common_params_handle_model(params.speculative.model, params.hf_token, params.offline);
common_params_handle_model(params.vocoder.model, params.hf_token, params.offline);
common_params_handle_model(params.speculative.mparams_dft, params.hf_token, params.offline);
common_params_handle_model(params.vocoder.model, params.hf_token, params.offline);
}
// model is required (except for server)
@@ -1219,16 +1220,16 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
{"-lcs", "--lookup-cache-static"}, "FNAME",
"path to static lookup cache to use for lookup decoding (not updated by generation)",
[](common_params & params, const std::string & value) {
params.lookup_cache_static = value;
params.speculative.lookup_cache_static = value;
}
).set_examples({LLAMA_EXAMPLE_LOOKUP}));
).set_examples({LLAMA_EXAMPLE_LOOKUP, LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"-lcd", "--lookup-cache-dynamic"}, "FNAME",
"path to dynamic lookup cache to use for lookup decoding (updated by generation)",
[](common_params & params, const std::string & value) {
params.lookup_cache_dynamic = value;
params.speculative.lookup_cache_dynamic = value;
}
).set_examples({LLAMA_EXAMPLE_LOOKUP}));
).set_examples({LLAMA_EXAMPLE_LOOKUP, LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"-c", "--ctx-size"}, "N",
string_format("size of the prompt context (default: %d, 0 = loaded from model)", params.n_ctx),
@@ -1303,7 +1304,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
[](common_params & params, bool value) {
params.kv_unified = value;
}
).set_env("LLAMA_ARG_KV_UNIFIED").set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_BATCHED}));
).set_env("LLAMA_ARG_KV_UNIFIED").set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_BATCHED, LLAMA_EXAMPLE_BENCH}));
add_opt(common_arg(
{"--context-shift"},
{"--no-context-shift"},
@@ -2566,7 +2567,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
{"-hfd", "-hfrd", "--hf-repo-draft"}, "<user>/<model>[:quant]",
"Same as --hf-repo, but for the draft model (default: unused)",
[](common_params & params, const std::string & value) {
params.speculative.model.hf_repo = value;
params.speculative.mparams_dft.hf_repo = value;
}
).set_env("LLAMA_ARG_HFD_REPO"));
add_opt(common_arg(
@@ -3387,7 +3388,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
{"-md", "--model-draft"}, "FNAME",
"draft model for speculative decoding (default: unused)",
[](common_params & params, const std::string & value) {
params.speculative.model.path = value;
params.speculative.mparams_dft.path = value;
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_MODEL_DRAFT"));
add_opt(common_arg(
@@ -3397,6 +3398,66 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.speculative.replacements.push_back({ tgt, dft });
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}));
add_opt(common_arg(
{"--spec-type"}, "[none|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v]",
string_format("type of speculative decoding to use when no draft model is provided (default: %s)\n",
common_speculative_type_to_str(params.speculative.type).c_str()),
[](common_params & params, const std::string & value) {
if (value == "none") {
params.speculative.type = COMMON_SPECULATIVE_TYPE_NONE;
} else if (value == "ngram-cache") {
params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_CACHE;
} else if (value == "ngram-simple") {
params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE;
} else if (value == "ngram-map-k") {
params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K;
} else if (value == "ngram-map-k4v") {
params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V;
} else {
throw std::invalid_argument("unknown speculative decoding type without draft model");
}
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"--spec-ngram-size-n"}, "N",
string_format("ngram size N for ngram-simple/ngram-map speculative decoding, length of lookup n-gram (default: %d)", params.speculative.ngram_size_n),
[](common_params & params, int value) {
if (value < 1 || value > 1024) {
throw std::invalid_argument("ngram size N must be between 1 and 1024 inclusive");
}
params.speculative.ngram_size_n = value;
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"--spec-ngram-size-m"}, "N",
string_format("ngram size M for ngram-simple/ngram-map speculative decoding, length of draft m-gram (default: %d)", params.speculative.ngram_size_m),
[](common_params & params, int value) {
if (value < 1 || value > 1024) {
throw std::invalid_argument("ngram size M must be between 1 and 1024 inclusive");
}
params.speculative.ngram_size_m = value;
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"--spec-ngram-check-rate"}, "N",
string_format("ngram check rate for ngram-simple/ngram-map speculative decoding (default: %d)", params.speculative.ngram_check_rate),
[](common_params & params, int value) {
if (value < 1) {
throw std::invalid_argument("ngram check rate must be at least 1");
}
params.speculative.ngram_check_rate = value;
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"--spec-ngram-min-hits"}, "N",
string_format("minimum hits for ngram-map speculative decoding (default: %d)", params.speculative.ngram_min_hits),
[](common_params & params, int value) {
if (value < 1) {
throw std::invalid_argument("ngram min hits must be at least 1");
}
params.speculative.ngram_min_hits = value;
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
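// Example (sketch; the "llama-server" binary name and all values are illustrative):
//   llama-server -m model.gguf --spec-type ngram-map-k \
//       --spec-ngram-size-n 8 --spec-ngram-size-m 32 --spec-ngram-min-hits 2
// The five options above fill params.speculative.type, ngram_size_n, ngram_size_m,
// ngram_check_rate and ngram_min_hits for draft-model-free speculation.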
add_opt(common_arg(
{"-ctkd", "--cache-type-k-draft"}, "TYPE",
string_format(
@@ -3623,8 +3684,8 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
[](common_params & params) {
params.model.hf_repo = "ggml-org/Qwen2.5-Coder-7B-Q8_0-GGUF";
params.model.hf_file = "qwen2.5-coder-7b-q8_0.gguf";
params.speculative.model.hf_repo = "ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF";
params.speculative.model.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf";
params.speculative.mparams_dft.hf_repo = "ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF";
params.speculative.mparams_dft.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf";
params.port = 8012;
params.n_ubatch = 1024;
params.n_batch = 1024;
@@ -3639,8 +3700,8 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
[](common_params & params) {
params.model.hf_repo = "ggml-org/Qwen2.5-Coder-14B-Q8_0-GGUF";
params.model.hf_file = "qwen2.5-coder-14b-q8_0.gguf";
params.speculative.model.hf_repo = "ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF";
params.speculative.model.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf";
params.speculative.mparams_dft.hf_repo = "ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF";
params.speculative.mparams_dft.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf";
params.port = 8012;
params.n_ubatch = 1024;
params.n_batch = 1024;

common/chat.cpp

@@ -785,10 +785,12 @@ static std::string apply(
nlohmann::ordered_json inp = nlohmann::ordered_json{
{"messages", messages_override.has_value() ? *messages_override : inputs.messages},
{"tools", tools_override.has_value() ? *tools_override : inputs.tools},
{"bos_token", tmpl.bos_token()},
{"eos_token", tmpl.eos_token()},
};
if (tools_override.has_value() || !inputs.tools.empty()) {
inp["tools"] = tools_override.has_value() ? *tools_override : inputs.tools;
}
if (inputs.extra_context.is_object()) {
// TODO: do we need to merge, or replacing is fine?
for (const auto & [k, v] : inputs.extra_context.items()) {
@@ -804,9 +806,6 @@ static std::string apply(
if (inputs.add_generation_prompt) {
inp["add_generation_prompt"] = true;
}
if (inp["tools"].is_null()) {
inp["tools"] = json::array();
}
jinja::global_from_json(ctx, inp, inputs.mark_input);
@@ -2233,12 +2232,11 @@ static common_chat_params common_chat_params_init_glm_4_5(const common_chat_temp
static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct templates_params & inputs) {
LOG_DBG("%s\n", __func__);
common_chat_params data;
const std::optional<json> tools_override = json();
const std::optional<json> additional_context = json {
{"datetime", format_time(inputs.now, "%b %d %Y %H:%M:%S GMT")},
{"functions", json(inputs.tools.empty() ? "" : inputs.tools.dump(2))},
};
data.prompt = apply(tmpl, inputs, /* messages_override =*/ std::nullopt, tools_override, additional_context);
data.prompt = apply(tmpl, inputs, /* messages_override =*/ std::nullopt, /* tools_override =*/ std::nullopt, additional_context);
if (inputs.tools.is_array() && !inputs.tools.empty()) {
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
data.grammar = build_grammar([&](const common_grammar_builder & builder) {

common/common.cpp

@@ -1104,7 +1104,10 @@ common_init_result::common_init_result(common_params & params) :
if (params.fit_params) {
LOG_INF("%s: fitting params to device memory, for bugs during this step try to reproduce them with -fit off, or provide --verbose logs if the bug only occurs with -fit on\n", __func__);
llama_params_fit(params.model.path.c_str(), &mparams, &cparams,
params.tensor_split, params.tensor_buft_overrides.data(), params.fit_params_target.data(), params.fit_params_min_ctx,
params.tensor_split,
params.tensor_buft_overrides.data(),
params.fit_params_target.data(),
params.fit_params_min_ctx,
params.verbosity >= 4 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_ERROR);
}
@@ -1215,10 +1218,6 @@ std::vector<llama_adapter_lora_ptr> & common_init_result::lora() {
return pimpl->lora;
}
void common_init_result::free_context() {
pimpl->context.reset();
}
common_init_result_ptr common_init_from_params(common_params & params) {
common_init_result_ptr res(new common_init_result(params));

common/common.h

@@ -161,6 +161,16 @@ enum common_params_sampling_config : uint64_t {
COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA = 1 << 11,
};
enum common_speculative_type {
COMMON_SPECULATIVE_TYPE_NONE, // no speculative decoding
COMMON_SPECULATIVE_TYPE_DRAFT, // draft model
COMMON_SPECULATIVE_TYPE_EAGLE3, // eagle draft model
COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, // simple self-speculative decoding
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, // self-speculative decoding with n-gram keys only
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, // self-speculative decoding with n-gram keys and 4 m-gram values
COMMON_SPECULATIVE_TYPE_NGRAM_CACHE, // self-speculative decoding with 3-level n-gram cache
COMMON_SPECULATIVE_TYPE_COUNT // number of types, unknown type
};
// sampling parameters
struct common_params_sampling {
@@ -240,16 +250,35 @@ struct common_params_model {
};
struct common_params_speculative {
std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
common_speculative_type type = COMMON_SPECULATIVE_TYPE_NONE; // type of speculative decoding
int32_t n_ctx = 0; // draft context size
int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding
int32_t n_min = 0; // minimum number of draft tokens to use for speculative decoding
int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
float p_split = 0.1f; // speculative decoding split probability
float p_min = 0.75f; // minimum speculative decoding probability (greedy)
std::vector<std::pair<std::string, std::string>> replacements; // main to speculative model replacements
std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;
// general-purpose speculative decoding parameters
int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding
int32_t n_min = 0; // minimum number of draft tokens to use for speculative decoding
float p_split = 0.1f; // speculative decoding split probability
float p_min = 0.75f; // minimum speculative decoding probability (greedy)
// ngram-based speculative decoding
uint16_t ngram_size_n = 12; // ngram size for lookup
uint16_t ngram_size_m = 48; // mgram size for speculative tokens
uint16_t ngram_check_rate = 1; // check rate for ngram lookup
uint16_t ngram_min_hits = 1; // minimum hits at ngram/mgram lookup for mgram to be proposed
std::string lookup_cache_static; // path of static ngram cache file for lookup decoding // NOLINT
std::string lookup_cache_dynamic; // path of dynamic ngram cache file for lookup decoding // NOLINT
// draft-model speculative decoding
struct common_params_model mparams_dft;
llama_model * model_dft = nullptr; // a llama_model that can be shared by multiple speculative contexts
llama_context_params cparams_dft; // these are the parameters for the draft llama_context
int32_t n_ctx = 0; // draft context size
int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K
ggml_type cache_type_v = GGML_TYPE_F16; // KV cache data type for the V
@@ -257,7 +286,14 @@ struct common_params_speculative {
struct cpu_params cpuparams;
struct cpu_params cpuparams_batch;
struct common_params_model model;
std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
std::vector<std::pair<std::string, std::string>> replacements; // main to speculative model replacements
std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;
bool has_dft() const {
return !mparams_dft.path.empty() || !mparams_dft.hf_repo.empty();
}
};
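// has_dft() replaces the old common_params::has_speculative(). Usage sketch
// (caller code is illustrative, not from this commit):
//   if (params.speculative.has_dft()) {
//       // a draft model was configured (-md / --hf-repo-draft)
//   } else if (params.speculative.type != COMMON_SPECULATIVE_TYPE_NONE) {
//       // ngram-based self-speculation
//   }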
struct common_params_vocoder {
@@ -375,8 +411,6 @@ struct common_params {
std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state // NOLINT
std::string input_prefix = ""; // string to prefix user inputs with // NOLINT
std::string input_suffix = ""; // string to suffix user inputs with // NOLINT
std::string lookup_cache_static = ""; // path of static ngram cache file for lookup decoding // NOLINT
std::string lookup_cache_dynamic = ""; // path of dynamic ngram cache file for lookup decoding // NOLINT
std::string logits_file = ""; // file for saving *all* logits // NOLINT
// llama-debug specific options
@@ -572,10 +606,6 @@ struct common_params {
// return false from callback to abort model loading or true to continue
llama_progress_callback load_progress_callback = NULL;
void * load_progress_callback_user_data = NULL;
bool has_speculative() const {
return !speculative.model.path.empty() || !speculative.model.hf_repo.empty();
}
};
// call once at the start of a program if it uses libcommon
@@ -711,8 +741,6 @@ struct common_init_result {
std::vector<llama_adapter_lora_ptr> & lora();
void free_context();
private:
struct impl;
std::unique_ptr<impl> pimpl;

@@ -114,6 +114,18 @@ static T slice(const T & array, int64_t start, int64_t stop, int64_t step = 1) {
return result;
}
template<typename T>
static value empty_value_fn(const func_args &) {
if constexpr (std::is_same_v<T, value_int>) {
return mk_val<T>(0);
} else if constexpr (std::is_same_v<T, value_float>) {
return mk_val<T>(0.0);
} else if constexpr (std::is_same_v<T, value_bool>) {
return mk_val<T>(false);
} else {
return mk_val<T>();
}
}
template<typename T>
static value test_type_fn(const func_args & args) {
args.ensure_count(1);
@@ -128,6 +140,13 @@ static value test_type_fn(const func_args & args) {
JJ_DEBUG("test_type_fn: type=%s or %s result=%d", typeid(T).name(), typeid(U).name(), is_type ? 1 : 0);
return mk_val<value_bool>(is_type);
}
template<typename T, typename U, typename V>
static value test_type_fn(const func_args & args) {
args.ensure_count(1);
bool is_type = is_val<T>(args.get_pos(0)) || is_val<U>(args.get_pos(0)) || is_val<V>(args.get_pos(0));
JJ_DEBUG("test_type_fn: type=%s, %s or %s result=%d", typeid(T).name(), typeid(U).name(), typeid(V).name(), is_type ? 1 : 0);
return mk_val<value_bool>(is_type);
}
template<value_compare_op op>
static value test_compare_fn(const func_args & args) {
args.ensure_count(2, 2);
@@ -347,8 +366,8 @@ const func_builtins & global_builtins() {
{"test_is_integer", test_type_fn<value_int>},
{"test_is_float", test_type_fn<value_float>},
{"test_is_number", test_type_fn<value_int, value_float>},
{"test_is_iterable", test_type_fn<value_array, value_string>},
{"test_is_sequence", test_type_fn<value_array, value_string>},
{"test_is_iterable", test_type_fn<value_array, value_string, value_undefined>},
{"test_is_sequence", test_type_fn<value_array, value_string, value_undefined>},
{"test_is_mapping", test_type_fn<value_object>},
{"test_is_lower", [](const func_args & args) -> value {
args.ensure_vals<value_string>();
@@ -1003,7 +1022,22 @@ const func_builtins & value_none_t::get_builtins() const {
static const func_builtins builtins = {
{"default", default_value},
{"tojson", tojson},
{"string", [](const func_args &) -> value { return mk_val<value_string>("None"); }}
{"string", [](const func_args &) -> value {
return mk_val<value_string>("None");
}},
{"safe", [](const func_args &) -> value {
return mk_val<value_string>("None");
}},
{"strip", [](const func_args &) -> value {
return mk_val<value_string>("None");
}},
{"items", empty_value_fn<value_array>},
{"map", empty_value_fn<value_array>},
{"reject", empty_value_fn<value_array>},
{"rejectattr", empty_value_fn<value_array>},
{"select", empty_value_fn<value_array>},
{"selectattr", empty_value_fn<value_array>},
{"unique", empty_value_fn<value_array>},
};
return builtins;
}
@@ -1012,10 +1046,33 @@ const func_builtins & value_none_t::get_builtins() const {
const func_builtins & value_undefined_t::get_builtins() const {
static const func_builtins builtins = {
{"default", default_value},
{"tojson", [](const func_args & args) -> value {
args.ensure_vals<value_undefined>();
return mk_val<value_string>("null");
}},
{"capitalize", empty_value_fn<value_string>},
{"first", empty_value_fn<value_undefined>},
{"items", empty_value_fn<value_array>},
{"join", empty_value_fn<value_string>},
{"last", empty_value_fn<value_undefined>},
{"length", empty_value_fn<value_int>},
{"list", empty_value_fn<value_array>},
{"lower", empty_value_fn<value_string>},
{"map", empty_value_fn<value_array>},
{"max", empty_value_fn<value_undefined>},
{"min", empty_value_fn<value_undefined>},
{"reject", empty_value_fn<value_array>},
{"rejectattr", empty_value_fn<value_array>},
{"replace", empty_value_fn<value_string>},
{"reverse", empty_value_fn<value_array>},
{"safe", empty_value_fn<value_string>},
{"select", empty_value_fn<value_array>},
{"selectattr", empty_value_fn<value_array>},
{"sort", empty_value_fn<value_array>},
{"string", empty_value_fn<value_string>},
{"strip", empty_value_fn<value_string>},
{"sum", empty_value_fn<value_int>},
{"title", empty_value_fn<value_string>},
{"truncate", empty_value_fn<value_string>},
{"unique", empty_value_fn<value_array>},
{"upper", empty_value_fn<value_string>},
{"wordcount", empty_value_fn<value_int>},
};
return builtins;
}
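// Effect (sketch): filters on undefined values now degrade gracefully instead of
// failing, e.g. {{ missing | length }} -> 0, {{ missing | join }} -> "",
// {{ missing | list }} -> [].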

common/ngram-cache.cpp

@@ -192,12 +192,12 @@ void common_ngram_cache_draft(
break;
}
LOG(" - draft candidate: token=%d\n", drafted_token);
LOG_DBG(" - draft candidate: token=%d\n", drafted_token);
draft.push_back(drafted_token);
}
}
void common_ngram_cache_save(common_ngram_cache & ngram_cache, std::string & filename) {
void common_ngram_cache_save(common_ngram_cache & ngram_cache, const std::string & filename) {
std::ofstream file_out(filename, std::ios::binary);
for (std::pair<common_ngram, common_ngram_cache_part> item : ngram_cache) {
const common_ngram ngram = item.first;
@@ -217,10 +217,9 @@ void common_ngram_cache_save(common_ngram_cache & ngram_cache, std::string & fil
file_out.write(reinterpret_cast<const char *>(&count), sizeof(int32_t));
}
}
}
common_ngram_cache common_ngram_cache_load(std::string & filename) {
common_ngram_cache common_ngram_cache_load(const std::string & filename) {
std::ifstream hashmap_file(filename, std::ios::binary);
if (!hashmap_file) {
throw std::ifstream::failure("Unable to open file " + filename);

common/ngram-cache.h

@@ -88,12 +88,12 @@
// Save an ngram cache to a file.
// ngram_cache: the ngram cache to save.
// filename: the path under which to save the ngram cache.
void common_ngram_cache_save(common_ngram_cache & ngram_cache, std::string & filename);
void common_ngram_cache_save(common_ngram_cache & ngram_cache, const std::string & filename);
// Load an ngram cache saved with common_ngram_cache_save.
// filename: the path from which to load the ngram cache.
// returns: an ngram cache containing the information saved to filename.
common_ngram_cache common_ngram_cache_load(std::string & filename);
common_ngram_cache common_ngram_cache_load(const std::string & filename);
// Merge two ngram caches.
// ngram_cache_target: the ngram cache to which to add the information from ngram_cache_add.

common/ngram-map.cpp Normal file

@@ -0,0 +1,367 @@
#include "common.h"
#include "log.h"
#include "ngram-map.h"
#include <cinttypes>
#include <cstdint>
#include <cstdio>
#include <sstream>
// n-gram simple
//
/**
* Perform speculative generation using the model's own token history.
* Searches for a matching pattern in the token history and returns draft tokens.
*
* @param state Current state of this implementation
* @param tokens Token history to search in
* @param sampled Last sampled token
* @return Vector of draft tokens, empty if no matching pattern is found
*/
llama_tokens common_ngram_simple_draft(
common_ngram_simple_state & state,
const llama_tokens & tokens, llama_token sampled) {
// Simple implementation of self-speculative decoding without a draft model.
//
const size_t cur_len = tokens.size();
// Only check every check_rate tokens to save compute
// i.e., perform check if (cur_len - idx_last_check) >= check_rate
if (state.idx_last_check + state.config.check_rate > cur_len) {
llama_tokens draft_tokens;
return draft_tokens;
}
size_t n_draft_min = state.config.size_ngram; // size of n-gram to lookup in token history
size_t n_draft_max = state.config.size_mgram; // the m-gram following the found n-gram is used for draft
// vector for tokens we want to verify.
// return empty vector if there is no match.
llama_tokens draft_tokens;
// We need at least n_draft_min + n_draft_max + 1 tokens.
if (cur_len <= static_cast<size_t>(n_draft_min + n_draft_max + 1)) {
return draft_tokens;
}
// pattern search
llama_tokens pattern;
pattern.reserve(n_draft_min);
for (size_t j = cur_len - n_draft_min + 1; j < cur_len; ++j) {
pattern.push_back(tokens[j]);
}
pattern.push_back(sampled); // add the last token to the pattern
// We do a search in the token history.
state.idx_last_check = cur_len;
size_t match_pos = 0; // we ignore position 0, position 0 == no match
// search backwards, but skip the current match (we are currently there)
for (size_t j = cur_len - n_draft_min - 1; j > 0; --j) {
bool match = true;
for (size_t k = 0; k < pattern.size(); ++k) {
if (tokens[j + k] != pattern[k]) {
match = false;
break;
}
}
if (match) {
match_pos = j;
break;
}
}
if (match_pos == 0) {
return draft_tokens;
}
const size_t copy_max = std::min(
n_draft_max,
cur_len - (match_pos + n_draft_min)
);
if (copy_max < n_draft_min) {
return draft_tokens;
}
LOG_DBG("%s: #tokens = %zu: found matching pattern at pos %zu, length %zu, draft length %zu\n",
__func__, cur_len,
match_pos, pattern.size(), copy_max);
draft_tokens.reserve(copy_max);
for (size_t j = 0; j < copy_max; ++j) {
draft_tokens.push_back(tokens[match_pos + n_draft_min + j]);
}
return draft_tokens;
}
// n-gram map
//
// maximum number of counted values of a ngram map value.
#define COMMON_NGRAM_MAX_VALUE_COUNT 16380
static std::string common_tokens_to_str(const llama_tokens & inp, size_t start, size_t length);
void common_ngram_map_draft(common_ngram_map & map,
const llama_tokens & inp, llama_token sampled,
llama_tokens & draft) {
// reset last key and value.
map.last_draft_created = false;
map.last_draft_key_idx = 0;
map.last_draft_value_idx = 0;
const size_t cur_len = inp.size();
const uint16_t n = map.size_key;
const uint16_t m = map.size_value;
if (cur_len < static_cast<size_t>(2 * n + m)) {
return;
}
// Only check every check_rate tokens to save compute
// i.e., perform check if (cur_len - idx_last_check) >= check_rate
if (map.idx_last_check + map.check_rate > cur_len) {
return;
}
map.idx_last_check = cur_len;
// search pattern, the key n-gram
std::vector<llama_token> key_tokens;
key_tokens.reserve(n);
for (size_t j = cur_len - n + 1; j < cur_len; ++j) {
key_tokens.push_back(inp[j]);
}
key_tokens.push_back(sampled);
// search for the key in the map
size_t match_pos = 0;
for (size_t j = cur_len - n - m - 1; j > 0; --j) {
bool match = true;
for (size_t k = 0; k < n; ++k) {
if (inp[j + k] != key_tokens[k]) {
match = false;
break;
}
}
if (match) {
match_pos = j;
break;
}
}
if (match_pos > 0) {
LOG_INF("%s: cur_len = %zu, n = %d, m = %d, sz_tkns = %zu, sampled = %d, match_pos = %zu\n", __func__,
cur_len, n, m, key_tokens.size(), sampled, match_pos);
}
if (match_pos == 0) {
return;
}
// We have a match, now we look for the statistics of the key.
size_t key_offset = map.keys.size(); // offset in the map
// We iterate through the std::vector<common_ngram_map_key> map->keys.
for (size_t i = 0; i < map.keys.size(); ++i) {
bool match = true;
for (size_t j = 0; j < n; ++j) {
if (inp[map.keys[i].key_idx + j] != key_tokens[j]) {
match = false;
break;
}
}
if (match) {
key_offset = i;
break;
}
}
if (key_offset == map.keys.size()) {
// We create a new key-entry, it will get offset key_offset.
common_ngram_map_key new_key;
new_key.key_idx = match_pos;
new_key.stat_idx = 0;
new_key.key_num = 0;
for (int i = 0; i < COMMON_NGRAM_MAX_VALUES; ++i) {
new_key.values[i].value_num = 0;
new_key.values[i].n_accepted = m;
}
map.keys.push_back(new_key);
}
// our key n-gram:
common_ngram_map_key & curr_key = map.keys[key_offset];
// update number of key hits
curr_key.key_num = (uint16_t) std::min((int) map.keys[key_offset].key_num + 1,
(int) COMMON_NGRAM_MAX_VALUE_COUNT);
if (map.key_only) {
// simple mode:
// Fill in the draft with the m tokens following the key.
// We work with value values[0] only.
int n_draft_tokens = std::min((int) m, (int) curr_key.values[0].n_accepted);
for (int i = 0; i < n_draft_tokens; ++i) {
draft.push_back(inp[match_pos + n + i]);
}
LOG_INF("%s: key_offset = %zu, key_num = %d, draft.size = %zu\n", __func__,
key_offset, curr_key.key_num, draft.size());
map.last_draft_created = false;
map.last_draft_key_idx = key_offset;
map.last_draft_value_idx = 0; // value 0 is used for simple mode
return;
}
if (curr_key.key_num < map.min_hits) {
// not enough hits to consider this a good draft
LOG_DBG("%s: key_offset = %zu, key_num = %d, min_hits = %d, no draft\n", __func__,
key_offset, curr_key.key_num, map.min_hits);
return;
}
// complex mode: examine the different m-grams after this key n-gram.
//
// determine all (max COMMON_NGRAM_MAX_VALUES) m-grams after the key n-gram.
for (size_t i = curr_key.stat_idx; i <= match_pos; ++i) {
// begins the key n-gram at index i?
bool match_key = true;
for (size_t k = 0; k < n; ++k) {
if (inp[i + k] != key_tokens[k]) {
match_key = false;
break;
}
}
if (!match_key) {
continue;
}
// Do we have an existing value m-gram or a new one after the key at index i?
size_t idx_begin_value_key = i + n;
int idx_value = -1;
for (int v = 0; v < COMMON_NGRAM_MAX_VALUES; ++v) {
size_t idx_begin_value_v = curr_key.values[v].value_idx;
if (idx_begin_value_v == 0) {
// We found an empty value slot => we found a new value m-gram after the key n-gram.
curr_key.values[v].value_idx = idx_begin_value_key;
curr_key.values[v].value_num = 0;
curr_key.values[v].n_accepted = m;
idx_value = v;
break;
}
bool match = true;
for (size_t j = 0; j < m; ++j) {
if (inp[idx_begin_value_key + j] != inp[idx_begin_value_v + j]) {
match = false;
break;
}
}
if (match) {
// We found an existing value m-gram after the key n-gram.
idx_value = v;
break;
}
}
if (idx_value >= 0) {
// We found a value m-gram of the key n-gram.
curr_key.values[idx_value].value_num = (uint16_t) std::min((int) curr_key.values[idx_value].value_num + 1,
(int) COMMON_NGRAM_MAX_VALUE_COUNT);
}
}
// the statistics are updated up to match_pos.
curr_key.stat_idx = match_pos;
// Do we have a value we could use for the draft?
uint16_t max_occur = 0;
int slot_max = 0;
for (int v = 0; v < COMMON_NGRAM_MAX_VALUES; ++v) {
uint16_t curr_occur = curr_key.values[v].value_num;
if (curr_occur > max_occur) {
max_occur = curr_occur;
slot_max = v;
}
}
// What is the sum of the other occurrences?
uint32_t sum_occur = 0;
for (int v = 0; v < COMMON_NGRAM_MAX_VALUES; ++v) {
if (v == slot_max) {
continue;
}
uint16_t curr_occur = curr_key.values[v].value_num;
sum_occur += curr_occur;
}
LOG_INF("%s: key_offset = %zu, max_occur = %d, sum_occur = %d, slot_max = %d [%zu/%d, %zu/%d, %zu/%d, %zu/%d]\n", __func__,
key_offset,
max_occur, sum_occur, slot_max,
curr_key.values[0].value_idx, curr_key.values[0].value_num,
curr_key.values[1].value_idx, curr_key.values[1].value_num,
curr_key.values[2].value_idx, curr_key.values[2].value_num,
curr_key.values[3].value_idx, curr_key.values[3].value_num
);
// Print the tokens of the four values (if idx != 0), use LOG_INF
for (int v = 0; v < COMMON_NGRAM_MAX_VALUES; ++v) {
if (curr_key.values[v].value_idx != 0) {
LOG_INF("%s: value[%d] = %s\n", __func__, v, common_tokens_to_str(inp, curr_key.values[v].value_idx, m).c_str());
}
}
if (sum_occur > 0 && max_occur < 3 * sum_occur) {
// The most frequent value is not much more frequent than the other values.
// We do not use the draft.
return;
}
// We use the most frequent value values[slot_max] for the draft.
// Fill in the draft with the m tokens following the key.
int n_draft_tokens = std::min((int) m, (int) curr_key.values[slot_max].n_accepted);
for (int i = 0; i < n_draft_tokens; ++i) {
draft.push_back(inp[match_pos + n + i]);
}
LOG_INF("%s: key_offset = %zu, slot_max = %d, key_num = %d, draft.size = %zu\n", __func__,
key_offset, slot_max,
curr_key.key_num, draft.size());
map.last_draft_created = true;
map.last_draft_key_idx = key_offset;
map.last_draft_value_idx = slot_max; // value used for draft generation.
}
void common_ngram_map_accept(common_ngram_map & map, uint16_t n_accepted) {
if (!map.last_draft_created) {
return;
}
// find the key and its chosen value.
const size_t key_idx = map.last_draft_key_idx;
const size_t val_idx = map.last_draft_value_idx;
// find key corresponding to key_idx.
common_ngram_map_key & curr_key = map.keys[key_idx];
// find value corresponding to val_idx.
struct common_ngram_map_value & curr_value = curr_key.values[val_idx]; // value used for draft generation.
// update the value statistics
LOG_INF("common_ngram_map_send_accepted: n_accepted = %d, prev value_num = %d\n",
n_accepted, curr_value.n_accepted);
curr_value.n_accepted = n_accepted;
}
// Helper functions.
//
// Print the values of a sublist of `llama_tokens & inp` to a string in the form [v0, v1, v2, ...].
std::string common_tokens_to_str(const llama_tokens & inp, size_t start, size_t length) {
std::ostringstream oss;
oss << '[';
for (size_t i = 0; i < length; ++i) {
if (i > 0) {
oss << ", ";
}
oss << inp[start + i];
}
oss << ']';
return oss.str();
}

common/ngram-map.h Normal file

@@ -0,0 +1,105 @@
#pragma once
//
// common/ngram-map.h: structures used to manage a map from n-grams to a list of m-grams
//
// These structures are used to do a lookup of n-grams followed by m-grams in token history.
//
// There are two algorithms implemented:
// 1. ngram_simple: lookup of n-grams followed by m-grams in token history.
// 2. ngram_map: lookup of n-grams followed by m-grams in token history using a map.
// The map is a vector of key n-grams, and for each key n-gram there is a list of value m-grams.
//
#include "llama.h"
#include <vector>
// n-gram simple
//
// config of n-gram simple.
struct common_ngram_simple_config {
uint16_t size_ngram; // size of n-grams to lookup in self-mode
uint16_t size_mgram; // size of m-grams to draft in self-mode
uint16_t check_rate; // check for speculative decoding without a draft model once every check_rate tokens
};
// current state (and config) of n-gram simple.
struct common_ngram_simple_state {
common_ngram_simple_config config;
size_t idx_last_check = 0; // index of last check in context history (mutable)
common_ngram_simple_state(const common_ngram_simple_config & config)
: config(config) {}
};
// Searches for an n-gram in the history and checks whether a draft sequence should be generated.
// state: the ngram simple state to search in.
// tokens: the tokens generated so far.
// sampled: the token that was just sampled.
// returns: the draft tokens (empty if no matching pattern is found).
llama_tokens common_ngram_simple_draft(
common_ngram_simple_state & state,
const llama_tokens & tokens, llama_token sampled);
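// Usage sketch (illustrative only; the decode loop and variable names are hypothetical):
//   common_ngram_simple_config cfg = { /*size_ngram=*/12, /*size_mgram=*/48, /*check_rate=*/1 };
//   common_ngram_simple_state  st(cfg);
//   llama_tokens draft = common_ngram_simple_draft(st, history, last_sampled);
//   if (!draft.empty()) { /* verify the draft tokens with the target model */ }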
// n-gram map
//
// maximum number of m-gram values stored for each key n-gram.
#define COMMON_NGRAM_MAX_VALUES 4
// statistics of a m-gram after a known n-gram
struct common_ngram_map_value {
size_t value_idx = 0; // index of value m-gram in token-history (0 if unused)
uint16_t value_num = 0; // number of occurrences of this value m-gram after the key n-gram (0 in an unused values-slot)
int16_t n_accepted = -1; // number of accepted tokens at last draft (-1 if unused)
};
// statistics of a n-gram
struct common_ngram_map_key {
size_t key_idx; // index of key n-gram in token-history
size_t stat_idx; // index of last token of statistics computation (key_num, values)
uint16_t key_num; // number of occurrences of this key n-gram in token-history
common_ngram_map_value values[COMMON_NGRAM_MAX_VALUES]; // some known values after the key
};
// map from n-grams to following m-grams in token-history
struct common_ngram_map {
uint16_t size_key; // size of key n-grams
uint16_t size_value; // size of value m-grams
bool key_only; // true if only key n-grams are used, no values.
// first draft: vector only, no map.
std::vector<common_ngram_map_key> keys; // key n-grams which occur several times in token-history
uint16_t check_rate; // check for speculative decoding without a draft model once every check_rate tokens
uint16_t min_hits; // minimum number of key hits to consider a draft
common_ngram_map(uint16_t sz_key, uint16_t sz_value, bool only_keys,
uint16_t check_rate, uint16_t min_hits)
: size_key(sz_key), size_value(sz_value), key_only(only_keys),
check_rate(check_rate), min_hits(min_hits) {}
bool last_draft_created = false; // true if a draft was created at last call.
size_t last_draft_key_idx = 0; // index of last key used for draft generation.
uint16_t last_draft_value_idx = 0; // index of last value used for draft generation.
size_t idx_last_check = 0; // index of last check in context history
};
// Searches for the n-gram in the history and checks whether a draft sequence should be generated.
// map: the ngram map to search in.
// inp: the tokens generated so far.
// sampled: the token that was just sampled.
// draft: vector to store the draft tokens, initially empty.
void common_ngram_map_draft(
common_ngram_map & map,
const llama_tokens & inp, llama_token sampled,
llama_tokens & draft);
// Update the statistics of a value after a draft was processed.
void common_ngram_map_accept(common_ngram_map & map, uint16_t n_accepted);
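// Usage sketch (illustrative only; history, last_sampled and n_accepted are hypothetical):
//   common_ngram_map map(/*sz_key=*/12, /*sz_value=*/48, /*only_keys=*/false,
//                        /*check_rate=*/1, /*min_hits=*/2);
//   llama_tokens draft;
//   common_ngram_map_draft(map, history, last_sampled, draft);
//   // ... the target model verifies the draft ...
//   common_ngram_map_accept(map, n_accepted);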

File diff suppressed because it is too large

common/speculative.h

@@ -5,31 +5,33 @@
struct common_speculative;
struct common_speculative_params {
int n_draft = 16; // max drafted tokens
int n_reuse = 256;
// comma separated list of all types
std::string common_speculative_type_name_str();
float p_min = 0.75f; // min probability required to accept a token in the draft
};
// convert string to type
enum common_speculative_type common_speculative_type_from_name(const std::string & name);
struct common_speculative * common_speculative_init(
struct llama_context * ctx_tgt,
struct llama_context * ctx_dft
);
// convert type to string
std::string common_speculative_type_to_str(enum common_speculative_type type);
void common_speculative_free(struct common_speculative * spec);
common_speculative * common_speculative_init(
const common_params_speculative & params,
llama_context * ctx_tgt);
bool common_speculative_are_compatible(
const struct llama_context * ctx_tgt,
const struct llama_context * ctx_dft);
void common_speculative_free(common_speculative * spec);
void common_speculative_add_replacement_tgt_dft(
struct common_speculative * spec,
const char *source, const char *dest);
// optionally call once at the beginning of a new generation
void common_speculative_begin(common_speculative * spec, const llama_tokens & prompt);
// sample up to n_draft tokens and add them to the batch using the draft model
llama_tokens common_speculative_gen_draft(
struct common_speculative * spec,
struct common_speculative_params params,
const llama_tokens & prompt,
llama_token id_last);
llama_tokens common_speculative_draft(
common_speculative * spec,
const common_params_speculative & params,
const llama_tokens & prompt,
llama_token id_last);
// informs the speculative decoder that n_accepted tokens were accepted by the target model
void common_speculative_accept(common_speculative * spec, uint16_t n_accepted);
// print statistics about the speculative decoding
void common_speculative_print_stats(const common_speculative * spec);
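// Lifecycle sketch of the reworked API (illustrative only; ctx_tgt, params, prompt
// and id_last are assumed to exist in the caller):
//   common_speculative * spec = common_speculative_init(params.speculative, ctx_tgt);
//   common_speculative_begin(spec, prompt);  // optional, once per generation
//   llama_tokens draft = common_speculative_draft(spec, params.speculative, prompt, id_last);
//   // ... the target model verifies the draft, accepting n_accepted tokens ...
//   common_speculative_accept(spec, n_accepted);
//   common_speculative_print_stats(spec);
//   common_speculative_free(spec);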

convert_hf_to_gguf.py

@@ -8912,13 +8912,16 @@ class GraniteHybridModel(Mamba2Model, GraniteMoeModel):
name.endswith("block_sparse_moe.input_linear.weight")
or "shared_mlp" in name
):
return GraniteMoeModel.modify_tensors(self, data_torch, name, bid)
yield from GraniteMoeModel.modify_tensors(self, data_torch, name, bid)
return
# Determine whether this is a mamba layer or an attention layer
if bid in self._ssm_layers:
return Mamba2Model.modify_tensors(self, data_torch, name, bid)
yield from Mamba2Model.modify_tensors(self, data_torch, name, bid)
return
elif bid in self._attn_layers:
return GraniteMoeModel.modify_tensors(self, data_torch, name, bid)
yield from GraniteMoeModel.modify_tensors(self, data_torch, name, bid)
return
yield from ModelBase.modify_tensors(self, data_torch, name, bid)
def set_gguf_parameters(self):

ggml/src/ggml-cuda/ggml-cuda.cu

@@ -3092,63 +3092,166 @@ static bool ggml_cuda_should_fuse_rope_set_rows(const ggml_tensor * rope,
return true;
}
static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops, std::initializer_list<enum ggml_unary_op> unary_ops) {
static bool ggml_cuda_topk_moe_fusion(const struct ggml_cgraph * cgraph, int node_idx, ggml_cuda_topk_moe_args & args) {
args.sigmoid = false;
args.softmax = false;
args.delayed_softmax = false;
args.prob_bias = false;
args.norm = false;
const int n_nodes = cgraph->n_nodes;
ggml_tensor ** nodes = cgraph->nodes;
if (nodes[node_idx]->op == GGML_OP_SOFT_MAX) {
args.softmax = true;
}
if (nodes[node_idx]->op == GGML_OP_UNARY) {
if (ggml_get_unary_op(nodes[node_idx]) != GGML_UNARY_OP_SIGMOID) {
return false;
}
args.sigmoid = true;
}
if (nodes[node_idx]->op == GGML_OP_ARGSORT) {
args.delayed_softmax = true;
}
node_idx++;
if (args.sigmoid || args.softmax) {
// SOFTMAX -> RESHAPE
if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_RESHAPE ||
nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
return false;
}
ggml_tensor * probs_reshaped = nodes[node_idx];
node_idx++;
if (node_idx >= n_nodes) {
return false;
}
// src of bias add is the unreshaped probs (-2 instead of -1)
if (nodes[node_idx]->op == GGML_OP_ADD && nodes[node_idx]->src[0] == nodes[node_idx - 2]) {
args.prob_bias = true;
node_idx++;
}
// RESHAPE/ADD -> ARGSORT
if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_ARGSORT) {
return false;
}
if (args.prob_bias && nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
return false;
} else if (!args.prob_bias && nodes[node_idx]->src[0] != nodes[node_idx - 2]) {
return false;
}
node_idx++;
// ARGSORT-> VIEW
if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_VIEW ||
nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
return false;
}
node_idx++;
if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_GET_ROWS) {
return false;
}
// GET_ROWS
if (nodes[node_idx]->src[0] != probs_reshaped || nodes[node_idx]->src[1] != nodes[node_idx - 1]) {
return false;
}
node_idx++;
} else if (args.delayed_softmax) {
if (node_idx - 2 < 0) {
return false;
}
ggml_tensor * probs_reshaped = nodes[node_idx - 2];
// VIEW->ARGSORT
if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_VIEW ||
nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
return false;
}
node_idx++;
// GET_ROWS
if (node_idx >= n_nodes || nodes[node_idx]->src[1] != nodes[node_idx - 1] ||
nodes[node_idx]->src[0] != probs_reshaped) {
return false;
}
node_idx++;
static const std::vector<ggml_op> remaining_ops = { GGML_OP_RESHAPE, GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };
for (const ggml_op op : remaining_ops) {
if (node_idx >= n_nodes || nodes[node_idx]->op != op || nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
return false;
}
node_idx++;
}
}
// At this point we can check for norm + scale. Everything is now at least valid till the norm
if (node_idx >= n_nodes) {
return true;
}
if (nodes[node_idx]->op == GGML_OP_RESHAPE) {
//check RESHAPE->SUM_ROWS->CLAMP->DIV->RESHAPE
static const std::vector<ggml_op> norm_ops = { GGML_OP_RESHAPE, GGML_OP_SUM_ROWS, GGML_OP_CLAMP };
args.norm = true;
for (const ggml_op op : norm_ops) {
if (nodes[node_idx]->op == op && nodes[node_idx]->src[0] == nodes[node_idx - 1]) {
node_idx++;
} else {
args.norm = false;
return true;
}
}
// DIV <- CLAMP, RESHAPE
if (nodes[node_idx]->op != GGML_OP_DIV || nodes[node_idx]->src[1] != nodes[node_idx - 1] ||
nodes[node_idx]->src[0] != nodes[node_idx - 3]) {
args.norm = false;
return true;
}
node_idx++;
if (nodes[node_idx]->op != GGML_OP_RESHAPE || nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
args.norm = false;
return true;
}
node_idx++;
}
if (nodes[node_idx]->op == GGML_OP_SCALE && nodes[node_idx]->src[0] == nodes[node_idx - 1]) {
args.scale = true;
}
return true;
}
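// Pattern summary for the checks above:
//   gated path   : [SOFT_MAX | SIGMOID] -> RESHAPE [-> ADD bias] -> ARGSORT -> VIEW -> GET_ROWS,
//                  optionally followed by RESHAPE -> SUM_ROWS -> CLAMP -> DIV -> RESHAPE (args.norm)
//                  and/or SCALE (args.scale)
//   delayed path : ARGSORT -> VIEW -> GET_ROWS -> RESHAPE -> SOFT_MAX -> RESHAPE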
static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph,
int node_idx,
std::initializer_list<enum ggml_op> ops,
std::initializer_list<enum ggml_unary_op> unary_ops) {
#ifndef NDEBUG
const size_t num_unary = std::count(ops.begin(), ops.end(), GGML_OP_UNARY);
GGML_ASSERT(unary_ops.size() == num_unary);
#endif
//TODO: remove special case once ggml_can_fuse can handle empty nodes
std::initializer_list<enum ggml_op> topk_moe_ops =
ggml_cuda_topk_moe_ops(/*with_norm*/ false, /*delayed_softmax=*/false);
std::initializer_list<enum ggml_op> topk_moe_ops_with_norm =
ggml_cuda_topk_moe_ops(/*with_norm=*/true, /*delayed_softmax=*/false);
std::initializer_list<enum ggml_op> topk_moe_ops_delayed_softmax =
ggml_cuda_topk_moe_ops(/*with_norm=*/false, /*delayed_softmax=*/true);
const auto is_equal = [](const std::initializer_list<enum ggml_op> & list1,
const std::initializer_list<enum ggml_op> & list2) {
return std::equal(list1.begin(), list1.end(), list2.begin(), list2.end());
};
if (is_equal(topk_moe_ops_with_norm, ops) &&
ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 9 })) {
ggml_tensor * softmax = cgraph->nodes[node_idx];
ggml_tensor * weights = cgraph->nodes[node_idx + 9];
ggml_tensor * get_rows = cgraph->nodes[node_idx + 4];
ggml_tensor * argsort = cgraph->nodes[node_idx + 2];
int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0];
if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) {
return true;
}
}
if (is_equal(topk_moe_ops, ops) && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 4 })) {
ggml_tensor * softmax = cgraph->nodes[node_idx];
ggml_tensor * weights = cgraph->nodes[node_idx + 4];
ggml_tensor * get_rows = cgraph->nodes[node_idx + 4];
ggml_tensor * argsort = cgraph->nodes[node_idx + 2];
int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0];
if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) {
return true;
}
}
if (is_equal(topk_moe_ops_delayed_softmax, ops) &&
ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 1, node_idx + 5 })) {
ggml_tensor * softmax = cgraph->nodes[node_idx + 4];
ggml_tensor * weights = cgraph->nodes[node_idx + 5];
ggml_tensor * get_rows = cgraph->nodes[node_idx + 2];
ggml_tensor * argsort = cgraph->nodes[node_idx + 0];
int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0];
if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) {
return true;
}
}
std::initializer_list<enum ggml_op> mul_mat_bias_glu_ops = { GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_GLU };
std::initializer_list<enum ggml_op> mul_mat_id_bias_glu_ops = { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_GLU };
@@ -3410,35 +3513,75 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
// start of fusion operations
static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
if (!disable_fusion) {
ggml_cuda_topk_moe_args args;
if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ true), {})) {
ggml_tensor * weights = cgraph->nodes[i + 9];
ggml_tensor * selected_experts = cgraph->nodes[i + 3];
ggml_tensor * clamp = cgraph->nodes[i + 7];
ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ true,
/*delayed softmax*/ false, clamp);
i += 9;
continue;
}
if (cgraph->nodes[i]->op == GGML_OP_UNARY || cgraph->nodes[i]->op == GGML_OP_SOFT_MAX ||
cgraph->nodes[i]->op == GGML_OP_ARGSORT) {
const bool can_fuse = ggml_cuda_topk_moe_fusion(cgraph, i, args);
if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ false), {})) {
ggml_tensor * weights = cgraph->nodes[i + 4];
ggml_tensor * selected_experts = cgraph->nodes[i + 3];
ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ false,
/*delayed softmax*/ false);
i += 4;
continue;
}
std::vector<ggml_op> ops;
if (ggml_cuda_can_fuse(cgraph, i,
ggml_cuda_topk_moe_ops(/*with norm*/ false, /*delayed softmax*/ true), {})) {
ggml_tensor * weights = cgraph->nodes[i + 5];
ggml_tensor * ids = cgraph->nodes[i + 1];
if (can_fuse) {
const ggml_tensor * logits = node->src[0];
ggml_tensor * weights = nullptr;
ggml_tensor * ids = nullptr;
const ggml_tensor * bias = nullptr;
const ggml_tensor * clamp = nullptr;
const ggml_tensor * scale = nullptr;
ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, ids, /*with norm*/ false,
/*delayed_softmax*/ true);
i += 5;
continue;
if (!args.delayed_softmax) {
ggml_op gating_op = args.sigmoid ? GGML_OP_UNARY : GGML_OP_SOFT_MAX;
int out_nodes[2]; // nodes which can't be elided
if (args.prob_bias) {
bias = cgraph->nodes[i + 2]->src[1];
ops.insert(ops.end(), { gating_op, GGML_OP_RESHAPE, GGML_OP_ADD, GGML_OP_ARGSORT,
GGML_OP_VIEW, GGML_OP_GET_ROWS });
out_nodes[0] = i + 4;
ids = cgraph->nodes[i + 4];
} else {
ops.insert(ops.end(), { gating_op, GGML_OP_RESHAPE, GGML_OP_ARGSORT, GGML_OP_VIEW,
GGML_OP_GET_ROWS });
out_nodes[0] = i + 3;
ids = cgraph->nodes[i + 3];
}
if (args.norm) {
ops.insert(ops.end(), { GGML_OP_RESHAPE, GGML_OP_SUM_ROWS, GGML_OP_CLAMP,
GGML_OP_DIV, GGML_OP_RESHAPE });
clamp = cgraph->nodes[i + ops.size() - 3];
}
if (args.scale) {
ops.insert(ops.end(), { GGML_OP_SCALE });
scale = cgraph->nodes[i + ops.size() - 1];
}
weights = cgraph->nodes[i + ops.size() - 1];
out_nodes[1] = i + ops.size() - 1;
if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) &&
ggml_cuda_should_use_topk_moe(node, logits, weights, ids)) {
ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args);
i += ops.size() - 1;
continue;
}
} else if (!args.norm && !args.prob_bias) {
// special case for gpt-oss: no norm, no bias.
ops.insert(ops.end(), { GGML_OP_ARGSORT, GGML_OP_VIEW, GGML_OP_GET_ROWS,
GGML_OP_RESHAPE, GGML_OP_SOFT_MAX, GGML_OP_RESHAPE });
weights = cgraph->nodes[i + 5];
ids = cgraph->nodes[i + 1];
const ggml_tensor * softmax = cgraph->nodes[i + 4];
int out_nodes[2] = { i + 1, i + 5 };
if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) &&
ggml_cuda_should_use_topk_moe(softmax, logits, weights, ids)) {
ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args);
i += ops.size() - 1;
continue;
}
}
}
}
if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, {})) {
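As an aside for readers following the new dispatch above: the candidate op list is built dynamically from the detected gating flags, then validated against the graph before fusing. A minimal standalone sketch of that matching idea (plain C++; the demo_* names are invented here, and the real ggml_can_fuse_subgraph additionally checks that intermediate tensors are not consumed outside the subgraph):

#include <cstddef>
#include <vector>

// Toy op tags standing in for the ggml_op values used above.
enum demo_op { DEMO_SOFT_MAX, DEMO_RESHAPE, DEMO_ADD, DEMO_ARGSORT, DEMO_VIEW, DEMO_GET_ROWS, DEMO_SCALE };

// Minimal shape of the check: the ops starting at `start` must match
// `expected` exactly; on top of this the real helper also validates the
// data-flow between the matched nodes.
static bool demo_subgraph_matches(const std::vector<demo_op> & nodes, std::size_t start,
                                  const std::vector<demo_op> & expected) {
    if (start + expected.size() > nodes.size()) {
        return false;
    }
    for (std::size_t k = 0; k < expected.size(); ++k) {
        if (nodes[start + k] != expected[k]) {
            return false;
        }
    }
    return true;
}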

View file

@ -333,7 +333,33 @@ namespace ggml_cuda_mma {
static __device__ __forceinline__ int get_j(const int l) {
if constexpr (I == 16 && J == 8) {
return 4 * (threadIdx.x / 16) + l;
return ne * (threadIdx.x / 16) + l;
} else {
NO_DEVICE_CODE;
return -1;
}
}
#elif defined(AMD_MFMA_AVAILABLE)
static constexpr int ne = I * J / 64;
half2 x[ne] = {{0.0f, 0.0f}};
static constexpr __device__ bool supported() {
if (I == 16 && J == 8) return true;
return false;
}
static __device__ __forceinline__ int get_i(const int l) {
if constexpr (I == 16 && J == 8) {
return threadIdx.x % 16;
} else {
NO_DEVICE_CODE;
return -1;
}
}
static __device__ __forceinline__ int get_j(const int l) {
if constexpr (I == 16 && J == 8) {
return ne * (threadIdx.x / 16) + l;
} else {
NO_DEVICE_CODE;
return -1;
@ -391,7 +417,22 @@ namespace ggml_cuda_mma {
static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR;
#if defined(AMD_WMMA_AVAILABLE)
static constexpr int ne = I * J / 32;
static constexpr int ne = tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::ne;
nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
static constexpr __device__ bool supported() {
return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::supported();
}
static __device__ __forceinline__ int get_i(const int l) {
return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::get_i(l);
}
static __device__ __forceinline__ int get_j(const int l) {
return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::get_j(l);
}
#elif defined(AMD_MFMA_AVAILABLE)
static constexpr int ne = tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::ne;
nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
static constexpr __device__ bool supported() {
@ -945,6 +986,32 @@ namespace ggml_cuda_mma {
#endif // AMPERE_MMA_AVAILABLE
}
template <data_layout dl_ab, data_layout dl_d>
static __device__ __forceinline__ void mma(
tile<16, 16, float, dl_d> & D, const tile<16, 8, float, dl_ab> & A, const tile<16, 8, float, dl_ab> & B) {
#ifdef AMD_MFMA_AVAILABLE
using floatx4_t = __attribute__((ext_vector_type(4))) float;
floatx4_t& acc_frag = reinterpret_cast<floatx4_t&>(D.x[0]);
#if defined(CDNA3)
using floatx2_t = __attribute__((ext_vector_type(2))) float;
const floatx2_t& a_frag = reinterpret_cast<const floatx2_t&>(A.x[0]);
const floatx2_t& b_frag = reinterpret_cast<const floatx2_t&>(B.x[0]);
acc_frag = __builtin_amdgcn_mfma_f32_16x16x8_xf32(a_frag, b_frag, acc_frag, 0, 0, 0);
#elif defined(CDNA2) || defined(CDNA1)
#pragma unroll
for (int i = 0; i < 2; ++i) {
acc_frag = __builtin_amdgcn_mfma_f32_16x16x4f32(A.x[i], B.x[i], acc_frag, 0, 0, 0);
}
#else
GGML_UNUSED_VARS(D, A, B);
NO_DEVICE_CODE;
#endif // defined(CDNA3)
#else
GGML_UNUSED_VARS(D, A, B);
NO_DEVICE_CODE;
#endif // AMD_MFMA_AVAILABLE
}
static __device__ __forceinline__ void mma_block_scaled(tile<16, 8, float> & D,
const tile<16, 8, int> & A,
const tile<8, 8, int> & B,
@ -1054,6 +1121,13 @@ namespace ggml_cuda_mma {
GGML_UNUSED_VARS(D, A, B);
NO_DEVICE_CODE;
#endif // RDNA4
#elif defined(AMD_MFMA_AVAILABLE)
using halfx4_t = __attribute__((ext_vector_type(4))) _Float16;
using floatx4_t = __attribute__((ext_vector_type(4))) float;
floatx4_t& acc_frag = reinterpret_cast<floatx4_t&>(D.x[0]);
const halfx4_t& a_frag = reinterpret_cast<const halfx4_t&>(A.x[0]);
const halfx4_t& b_frag = reinterpret_cast<const halfx4_t&>(B.x[0]);
acc_frag = __builtin_amdgcn_mfma_f32_16x16x16f16(a_frag, b_frag, acc_frag, 0, 0, 0);
#else
GGML_UNUSED_VARS(D, A, B);
NO_DEVICE_CODE;
@ -1081,11 +1155,31 @@ namespace ggml_cuda_mma {
#else
GGML_UNUSED_VARS(D, A, B);
NO_DEVICE_CODE;
#endif // RDNA4
#endif // defined(RDNA4)
#elif defined(AMD_MFMA_AVAILABLE)
using floatx4_t = __attribute__((ext_vector_type(4))) float;
floatx4_t& acc_frag = reinterpret_cast<floatx4_t&>(D.x[0]);
#if defined(CDNA3) || defined(CDNA2)
using bf16x4_t = __attribute__((ext_vector_type(4))) __bf16;
const bf16x4_t& a_frag = reinterpret_cast<const bf16x4_t&>(A.x[0]);
const bf16x4_t& b_frag = reinterpret_cast<const bf16x4_t&>(B.x[0]);
acc_frag = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_frag, b_frag, acc_frag, 0, 0, 0);
#elif defined(CDNA1)
#pragma unroll
for (int i = 0; i < 2; ++i) {
using bf16x2_t = __attribute__((ext_vector_type(2))) __bf16;
const bf16x2_t& a_frag = reinterpret_cast<const bf16x2_t&>(A.x[i]);
const bf16x2_t& b_frag = reinterpret_cast<const bf16x2_t&>(B.x[i]);
acc_frag = __builtin_amdgcn_mfma_f32_16x16x8bf16(a_frag, b_frag, acc_frag, 0, 0, 0);
}
#else
GGML_UNUSED_VARS(D, A, B);
NO_DEVICE_CODE;
#endif // AMPERE_MMA_AVAILABLE
#endif // defined(CDNA3) || defined(CDNA2)
#else
GGML_UNUSED_VARS(D, A, B);
NO_DEVICE_CODE;
#endif // defined(AMD_WMMA_AVAILABLE)
}
template <data_layout dl_d, data_layout dl_ab>
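For orientation on the MFMA paths added above: each __builtin_amdgcn_mfma_f32_16x16x16f16 call performs one 16x16x16 matrix multiply-accumulate with f32 accumulation (f16 inputs in hardware), distributed across the 64 lanes of a wavefront. A scalar reference of the numerical result, assuming row-major 16x16 tiles and ignoring the per-lane fragment layout that the tile<> types encode:

// Scalar reference for one 16x16x16 MFMA accumulation step: D += A * B.
// The actual instruction spreads the A/B/D fragments across 64 lanes;
// this loop only documents the math, not the layout.
static void demo_mfma_16x16x16(const float A[16][16], const float B[16][16], float D[16][16]) {
    for (int i = 0; i < 16; ++i) {
        for (int j = 0; j < 16; ++j) {
            float acc = D[i][j];
            for (int k = 0; k < 16; ++k) {
                acc += A[i][k] * B[k][j];
            }
            D[i][j] = acc;
        }
    }
}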

View file

@ -2,6 +2,13 @@
#include "mmf.cuh"
#include "mmid.cuh"
static __forceinline__ int mmf_get_rows_per_block(const int cc) {
if (GGML_CUDA_CC_IS_CDNA(cc)) {
return MMF_ROWS_PER_BLOCK_CDNA;
} else {
return MMF_ROWS_PER_BLOCK;
}
}
void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
GGML_ASSERT( src1->type == GGML_TYPE_F32);
@ -89,28 +96,32 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
ids_info_ptr = &ids_info;
}
const int device = ggml_cuda_get_device();
const int cc = ggml_cuda_info().devices[device].cc;
const int rows_per_block = mmf_get_rows_per_block(cc);
switch (src0->type) {
case GGML_TYPE_F32: {
const float * src0_d = (const float *) src0->data;
constexpr int vals_per_T = 1;
mul_mat_f_switch_cols_per_block(
src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
mul_mat_f_switch_rows_per_block<float>(
rows_per_block, src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
} break;
case GGML_TYPE_F16: {
const half2 * src0_d = (const half2 *) src0->data;
constexpr int vals_per_T = 2;
mul_mat_f_switch_cols_per_block(
src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
mul_mat_f_switch_rows_per_block<half2>(
rows_per_block, src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
} break;
case GGML_TYPE_BF16: {
const nv_bfloat162 * src0_d = (const nv_bfloat162 *) src0->data;
constexpr int vals_per_T = 2;
mul_mat_f_switch_cols_per_block(
src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
mul_mat_f_switch_rows_per_block<nv_bfloat162>(
rows_per_block, src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
} break;
@ -140,7 +151,11 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
return false;
}
}
if (src0_ne[1] % MMF_ROWS_PER_BLOCK != 0) {
if (src0_ne[1] % mmf_get_rows_per_block(cc) != 0) {
return false;
}
if (GGML_CUDA_CC_IS_CDNA3(cc) && type == GGML_TYPE_BF16) {
return false;
}
@ -153,6 +168,11 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
} else {
if (GGML_CUDA_CC_IS_RDNA3_0(cc) && src1_ncols > 8) {
return false;
} else if (GGML_CUDA_CC_IS_CDNA2(cc) && (type == GGML_TYPE_F16 || type == GGML_TYPE_BF16)) {
//TODO: treat CDNA2 as CDNA1; tune the performance when CDNA2 hardware is available.
return false;
} else if (GGML_CUDA_CC_IS_CDNA1(cc) && (type == GGML_TYPE_F16 || type == GGML_TYPE_BF16)) {
return false;
} else if (src1_ncols > 16) {
return false;
}
@ -160,11 +180,11 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
switch (type) {
case GGML_TYPE_F32:
return ampere_mma_available(cc);
return ampere_mma_available(cc) || amd_mfma_available(cc);
case GGML_TYPE_F16:
return volta_mma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc);
return volta_mma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc) || amd_mfma_available(cc);
case GGML_TYPE_BF16:
return ampere_mma_available(cc) || amd_wmma_available(cc);
return ampere_mma_available(cc) || amd_wmma_available(cc) || amd_mfma_available(cc);
default:
return false;
}
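A condensed host-side restatement of the two new MMF gates in this hunk: rows-per-block is now architecture-dependent, and the weight matrix must tile evenly into it. The demo_* names are invented for illustration; the real code dispatches on GGML_CUDA_CC_IS_CDNA(cc):

// Host-side sketch of the row gating above.
static int demo_mmf_rows_per_block(bool is_cdna) {
    return is_cdna ? 64 /* MMF_ROWS_PER_BLOCK_CDNA */ : 32 /* MMF_ROWS_PER_BLOCK */;
}

static bool demo_mmf_rows_ok(long long nrows, bool is_cdna) {
    // The weight matrix must split evenly into per-block row groups.
    return nrows % demo_mmf_rows_per_block(is_cdna) == 0;
}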

View file

@ -7,6 +7,31 @@
using namespace ggml_cuda_mma;
#define MMF_ROWS_PER_BLOCK 32
#define MMF_ROWS_PER_BLOCK_CDNA 64
static __forceinline__ int64_t mmf_get_max_block_size(int cc) {
if (GGML_CUDA_CC_IS_CDNA(cc)) {
return 512;
} else {
return 256;
}
}
static __forceinline__ int mmf_get_padding(int cc) {
if (GGML_CUDA_CC_IS_CDNA(cc)) {
return 2;
} else {
return 4;
}
}
static constexpr __device__ int mmf_get_padding() {
#if defined(AMD_MFMA_AVAILABLE)
return 2;
#else
return 4;
#endif // defined(AMD_MFMA_AVAILABLE)
}
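The value returned by mmf_get_padding() is added to shared-memory row strides (warp_size + pad and nwarps*rows_per_block + pad below); the usual purpose of such a pad is to stagger rows across shared-memory banks so column-wise reads do not conflict, and the smaller CDNA value presumably reflects its 64-lane wavefront geometry. A toy bank computation, under the common assumption of 32 four-byte banks:

// With row stride (warp_size + pad), element (row, k) lands in bank
// (row*(warp_size + pad) + k) % 32; a nonzero pad shifts each row so that
// reading one column across rows touches different banks.
static int demo_shmem_bank(int row, int k, int warp_size, int pad) {
    return (row * (warp_size + pad) + k) % 32;
}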
struct mmf_ids_data {
const int32_t * ids_src_compact = nullptr;
@ -29,23 +54,25 @@ static __global__ void mul_mat_f(
const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
// TODO: handle this in a consistent and simpler way after AMD MFMA support has been added
#if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
#if defined(AMD_WMMA_AVAILABLE)
// Special case for tf32, just dummy mma layout as wmma doesn't support it.
constexpr bool is_tf32 = std::is_same_v<T, float>;
constexpr int tile_B_I = is_tf32 ? 8 : 16;
constexpr int tile_C_J = is_tf32 ? 8 : 16;
constexpr data_layout ab_layout = is_tf32 ? DATA_LAYOUT_I_MAJOR : get_input_data_layout();
typedef tile<16, 8, T, ab_layout> tile_A;
typedef tile<tile_B_I, 8, T, ab_layout> tile_B;
typedef tile<16, tile_C_J, float, DATA_LAYOUT_J_MAJOR> tile_C;
if constexpr (!(std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
typedef tile<16, 8, T, get_input_data_layout()> tile_A;
typedef tile<16, 8, T, get_input_data_layout()> tile_B;
typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C;
#elif defined(AMD_MFMA_AVAILABLE)
if constexpr (rows_per_block != MMF_ROWS_PER_BLOCK_CDNA) {NO_DEVICE_CODE;} else {
typedef tile<16, 8, T, DATA_LAYOUT_I_MAJOR> tile_A;
typedef tile<16, 8, T, DATA_LAYOUT_I_MAJOR> tile_B;
typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C;
#else
#ifdef VOLTA_MMA_AVAILABLE
if constexpr (!std::is_same_v<T, half2>) {NO_DEVICE_CODE;} else {
if constexpr (!std::is_same_v<T, half2> || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
typedef tile<32, 4, T, DATA_LAYOUT_I_MAJOR> tile_A;
typedef tile< 8, 4, T, DATA_LAYOUT_I_MAJOR_MIRRORED> tile_B;
typedef tile<32, 8, float, DATA_LAYOUT_I_MAJOR> tile_C;
#else
if constexpr (rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
typedef tile<16, 8, T> tile_A;
typedef tile<8, 8, T> tile_B;
typedef tile<16, 8, float> tile_C;
@ -57,7 +84,7 @@ static __global__ void mul_mat_f(
}
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
constexpr int tile_k_padded = warp_size + 4;
constexpr int tile_k_padded = warp_size + mmf_get_padding();
constexpr int ntA = rows_per_block / tile_A::I;
constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I;
@ -198,7 +225,7 @@ static __global__ void mul_mat_f(
}
float * buf_iw = (float *) compute_base;
constexpr int kiw = nwarps*rows_per_block + 4;
constexpr int kiw = nwarps*rows_per_block + mmf_get_padding();
if (nwarps > 1) {
__syncthreads();
@ -228,27 +255,34 @@ static __global__ void mul_mat_f(
return;
}
float sum = 0.0f;
static_assert(rows_per_block == warp_size, "need loop/check");
float sum[rows_per_block/warp_size] = {0.0f};
static_assert((rows_per_block % warp_size) == 0, "rows_per_block must be a multiple of warp_size.");
#pragma unroll
for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) {
const int i = i0 + threadIdx.x;
#pragma unroll
for (int i1 = 0; i1 < sizeof(sum)/sizeof(sum[0]); ++i1) {
const int i = i0 + i1*warp_size + threadIdx.x;
sum += buf_iw[j*kiw + i];
sum[i1] += buf_iw[j*kiw + i];
}
}
if constexpr (!has_ids) {
dst[j*stride_col_dst + row0 + threadIdx.x] = sum;
#pragma unroll
for (int i0 = 0; i0 < sizeof(sum)/sizeof(sum[0]); ++i0) {
dst[j*stride_col_dst + row0 + i0*warp_size + threadIdx.x] = sum[i0];
}
} else {
const int slot = (j < cols_per_block) ? slot_map[j] : -1;
if (slot >= 0 && (col_base + j) < ncols_dst_total) {
dst[slot*stride_channel_dst + j*stride_col_dst + row0 + threadIdx.x] = sum;
#pragma unroll
for (int i0 = 0; i0 < sizeof(sum)/sizeof(sum[0]); ++i0) {
dst[slot*stride_channel_dst + j*stride_col_dst + row0 + i0*warp_size + threadIdx.x] = sum[i0];
}
}
}
}
#ifdef VOLTA_MMA_AVAILABLE
}
#endif //VOLTA_MMA_AVAILABLE
#else
GGML_UNUSED_VARS(x, y, ids, dst,
ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
@ -256,7 +290,7 @@ static __global__ void mul_mat_f(
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
NO_DEVICE_CODE;
#endif // (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
}
//This kernel is for larger batch sizes of mul_mat_id
@ -271,23 +305,25 @@ static __global__ void mul_mat_f_ids(
const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
const uint3 sis1_fd, const uint3 nch_fd) {
// TODO: handle this in a consistent and simpler way after AMD MFMA support has been added
#if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
#if defined(AMD_WMMA_AVAILABLE)
// Special case for tf32, just dummy mma layout as wmma doesn't support it.
constexpr bool is_tf32 = std::is_same_v<T, float>;
constexpr int tile_B_I = is_tf32 ? 8 : 16;
constexpr int tile_C_J = is_tf32 ? 8 : 16;
constexpr data_layout ab_layout = is_tf32 ? DATA_LAYOUT_I_MAJOR : get_input_data_layout();
typedef tile<16, 8, T, ab_layout> tile_A;
typedef tile<tile_B_I, 8, T, ab_layout> tile_B;
typedef tile<16, tile_C_J, float, DATA_LAYOUT_J_MAJOR> tile_C;
if constexpr (!(std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
typedef tile<16, 8, T, get_input_data_layout()> tile_A;
typedef tile<16, 8, T, get_input_data_layout()> tile_B;
typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C;
#elif defined(AMD_MFMA_AVAILABLE)
if constexpr (rows_per_block != MMF_ROWS_PER_BLOCK_CDNA) {NO_DEVICE_CODE;} else {
typedef tile<16, 8, T, DATA_LAYOUT_I_MAJOR> tile_A;
typedef tile<16, 8, T, DATA_LAYOUT_I_MAJOR> tile_B;
typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C;
#else
#ifdef VOLTA_MMA_AVAILABLE
if constexpr (!std::is_same_v<T, half2>) {NO_DEVICE_CODE;} else {
if constexpr (!std::is_same_v<T, half2> || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
typedef tile<32, 4, T, DATA_LAYOUT_I_MAJOR> tile_A;
typedef tile< 8, 4, T, DATA_LAYOUT_I_MAJOR_MIRRORED> tile_B;
typedef tile<32, 8, float, DATA_LAYOUT_I_MAJOR> tile_C;
#else
if constexpr (rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
typedef tile<16, 8, T> tile_A;
typedef tile<8, 8, T> tile_B;
typedef tile<16, 8, float> tile_C;
@ -300,7 +336,7 @@ static __global__ void mul_mat_f_ids(
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
constexpr int tile_k_padded = warp_size + 4;
constexpr int tile_k_padded = warp_size + mmf_get_padding();
constexpr int ntA = rows_per_block / tile_A::I;
constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I;
@ -467,7 +503,7 @@ static __global__ void mul_mat_f_ids(
}
float * buf_iw = (float *) compute_base;
constexpr int kiw = nwarps*rows_per_block + 4;
constexpr int kiw = nwarps*rows_per_block + mmf_get_padding();
if (nwarps > 1) {
__syncthreads();
@ -497,13 +533,16 @@ static __global__ void mul_mat_f_ids(
return;
}
float sum = 0.0f;
static_assert(rows_per_block == warp_size, "need loop/check");
float sum[rows_per_block/warp_size] = {0.0f};
static_assert((rows_per_block % warp_size) == 0, "rows_per_block must be a multiple of warp_size.");
#pragma unroll
for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) {
const int i = i0 + threadIdx.x;
#pragma unroll
for (int i1 = 0; i1 < sizeof(sum)/sizeof(sum[0]); ++i1) {
const int i = i0 + i1*warp_size + threadIdx.x;
sum += buf_iw[j*kiw + i];
sum[i1] += buf_iw[j * kiw + i];
}
}
const int global_j = col_base + j;
@ -513,23 +552,24 @@ static __global__ void mul_mat_f_ids(
const int token = (int) qrm.x;
if (token < ncols_dst_total) {
const int slot = (int) qrm.y;
dst[slot*stride_channel_dst + token*stride_col_dst + row0 + threadIdx.x] = sum;
#pragma unroll
for (int i0 = 0; i0 < sizeof(sum)/sizeof(sum[0]); ++i0) {
dst[slot * stride_channel_dst + token * stride_col_dst + row0 + i0*warp_size + threadIdx.x] = sum[i0];
}
}
}
}
#ifdef VOLTA_MMA_AVAILABLE
}
#endif // VOLTA_MMA_AVAILABLE
#else
GGML_UNUSED_VARS(x, y, ids_src_compact, ids_dst_compact, expert_bounds, dst,
ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, sis1_fd, nch_fd);
NO_DEVICE_CODE;
#endif // (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
}
template<typename T, int cols_per_block, int nwarps>
template<typename T, int rows_per_block, int cols_per_block, int nwarps>
static inline void mul_mat_f_switch_ids(
const T * x, const float * y, const int32_t * ids, float * dst,
const int64_t ncols_x, const int64_t ncols_dst, const int64_t nchannels_dst,
@ -553,7 +593,7 @@ static inline void mul_mat_f_switch_ids(
const uint3 sis1_fd = ids_data->sis1 > 0 ? init_fastdiv_values((uint32_t) ids_data->sis1) : make_uint3(0, 0, 1);
const uint3 nch_fd = init_fastdiv_values((uint32_t) nchannels_dst);
mul_mat_f_ids<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>>
mul_mat_f_ids<T, rows_per_block, cols_per_block, nwarps><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>>
(x, y, ids_data->ids_src_compact, ids_data->ids_dst_compact, ids_data->expert_bounds_dev, dst,
ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
@ -564,19 +604,19 @@ static inline void mul_mat_f_switch_ids(
dim3 block_nums_ids = block_nums;
block_nums_ids.y *= col_tiles;
mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>>
mul_mat_f<T, rows_per_block, cols_per_block, nwarps, true><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>>
(x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} else {
mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, false><<<block_nums, block_dims, nbytes_shared_total, stream>>>
mul_mat_f<T, rows_per_block, cols_per_block, nwarps, false><<<block_nums, block_dims, nbytes_shared_total, stream>>>
(x, y, ids, dst, ncols_x, cols_per_block, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
}
}
template <typename T, int cols_per_block>
template <typename T, int rows_per_block, int cols_per_block>
void mul_mat_f_cuda(
const T * x, const float * y, const int32_t * ids, float * dst,
const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst,
@ -605,7 +645,7 @@ void mul_mat_f_cuda(
int64_t nwarps_best = 1;
int64_t niter_best = (ncols_x + warp_size*2 - 1) / (warp_size*2);
int64_t max_block_size = 256;
int64_t max_block_size = mmf_get_max_block_size(cc);
for (int64_t nwarps = 2; nwarps <= max_block_size/warp_size; nwarps++) {
const int64_t niter = (ncols_x + nwarps*warp_size*2 - 1) / (nwarps*warp_size*2);
if (niter < niter_best) {
@ -614,10 +654,9 @@ void mul_mat_f_cuda(
}
}
constexpr int rows_per_block = MMF_ROWS_PER_BLOCK;
const int nbytes_shared_iter = nwarps_best * (volta_mma_available(cc) ? tile_A_32::I : tile_A_16::I) * (warp_size + 4) * 4;
const int nbytes_cols_per_block_pad = amd_wmma_available(cc) ? tile_B_16::I : tile_B_8::I;
const int nbytes_shared_combine = GGML_PAD(cols_per_block, nbytes_cols_per_block_pad) * (nwarps_best*rows_per_block + 4) * 4;
const int nbytes_shared_iter = nwarps_best * (volta_mma_available(cc) ? tile_A_32::I : tile_A_16::I) * (warp_size + mmf_get_padding(cc)) * 4;
const int nbytes_cols_per_block_pad = (amd_wmma_available(cc) || amd_mfma_available(cc)) ? tile_B_16::I : tile_B_8::I;
const int nbytes_shared_combine = GGML_PAD(cols_per_block, nbytes_cols_per_block_pad) * (nwarps_best*rows_per_block + mmf_get_padding(cc)) * 4;
const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine);
const int nbytes_slotmap = ids ? GGML_PAD(cols_per_block, 16) * sizeof(int) : 0;
const int nbytes_shared_total = nbytes_shared + nbytes_slotmap;
@ -628,56 +667,56 @@ void mul_mat_f_cuda(
switch (nwarps_best) {
case 1: {
mul_mat_f_switch_ids<T, cols_per_block, 1>(
mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 1>(
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
ids_data);
} break;
case 2: {
mul_mat_f_switch_ids<T, cols_per_block, 2>(
mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 2>(
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
ids_data);
} break;
case 3: {
mul_mat_f_switch_ids<T, cols_per_block, 3>(
mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 3>(
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
ids_data);
} break;
case 4: {
mul_mat_f_switch_ids<T, cols_per_block, 4>(
mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 4>(
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
ids_data);
} break;
case 5: {
mul_mat_f_switch_ids<T, cols_per_block, 5>(
mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 5>(
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
ids_data);
} break;
case 6: {
mul_mat_f_switch_ids<T, cols_per_block, 6>(
mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 6>(
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
ids_data);
} break;
case 7: {
mul_mat_f_switch_ids<T, cols_per_block, 7>(
mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 7>(
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
ids_data);
} break;
case 8: {
mul_mat_f_switch_ids<T, cols_per_block, 8>(
mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 8>(
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
@ -691,7 +730,7 @@ void mul_mat_f_cuda(
GGML_UNUSED_VARS(nchannels_y);
}
template <typename T>
template <typename T, int rows_per_block>
static void mul_mat_f_switch_cols_per_block(
const T * x, const float * y, const int32_t * ids, float * dst,
const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst,
@ -708,82 +747,82 @@ static void mul_mat_f_switch_cols_per_block(
switch (ncols_case) {
case 1: {
mul_mat_f_cuda<T, 1>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
mul_mat_f_cuda<T, rows_per_block, 1>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 2: {
mul_mat_f_cuda<T, 2>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
mul_mat_f_cuda<T, rows_per_block, 2>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 3: {
mul_mat_f_cuda<T, 3>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
mul_mat_f_cuda<T, rows_per_block, 3>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 4: {
mul_mat_f_cuda<T, 4>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
mul_mat_f_cuda<T, rows_per_block, 4>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 5: {
mul_mat_f_cuda<T, 5>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
mul_mat_f_cuda<T, rows_per_block, 5>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 6: {
mul_mat_f_cuda<T, 6>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
mul_mat_f_cuda<T, rows_per_block, 6>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 7: {
mul_mat_f_cuda<T, 7>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
mul_mat_f_cuda<T, rows_per_block, 7>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 8: {
mul_mat_f_cuda<T, 8>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
mul_mat_f_cuda<T, rows_per_block, 8>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 9: {
mul_mat_f_cuda<T, 9>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
mul_mat_f_cuda<T, rows_per_block, 9>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 10: {
mul_mat_f_cuda<T, 10>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
mul_mat_f_cuda<T, rows_per_block, 10>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 11: {
mul_mat_f_cuda<T, 11>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
mul_mat_f_cuda<T, rows_per_block, 11>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 12: {
mul_mat_f_cuda<T, 12>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
mul_mat_f_cuda<T, rows_per_block, 12>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 13: {
mul_mat_f_cuda<T, 13>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
mul_mat_f_cuda<T, rows_per_block, 13>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 14: {
mul_mat_f_cuda<T, 14>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
mul_mat_f_cuda<T, rows_per_block, 14>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 15: {
mul_mat_f_cuda<T, 15>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
mul_mat_f_cuda<T, rows_per_block, 15>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 16: {
mul_mat_f_cuda<T, 16>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
mul_mat_f_cuda<T, rows_per_block, 16>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
@ -793,8 +832,36 @@ static void mul_mat_f_switch_cols_per_block(
}
}
#define DECL_MMF_CASE_HELPER(T, ncols_dst) \
template void mul_mat_f_cuda<T, ncols_dst>( \
template <typename T>
static void mul_mat_f_switch_rows_per_block(
const int rows_per_block, const T * x, const float * y, const int32_t * ids, float * dst,
const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst,
const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
const int64_t stride_col_id, const int stride_row_id,
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
cudaStream_t stream, const mmf_ids_data * ids_data) {
switch (rows_per_block) {
case MMF_ROWS_PER_BLOCK: {
mul_mat_f_switch_cols_per_block<T, MMF_ROWS_PER_BLOCK>(
x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case MMF_ROWS_PER_BLOCK_CDNA: {
mul_mat_f_switch_cols_per_block<T, MMF_ROWS_PER_BLOCK_CDNA>(
x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
default:
GGML_ABORT("unsupported rows_per_block: %i", rows_per_block);
}
}
#define DECL_MMF_CASE_HELPER(T, nrows_dst, ncols_dst) \
template void mul_mat_f_cuda<T, nrows_dst, ncols_dst>( \
const T * x, const float * y, const int32_t * ids, float * dst, \
const int64_t ncols_x, const int64_t nrows_x, int64_t ncols_dst_total, const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, \
const int64_t stride_col_id, const int64_t stride_row_id, \
@ -803,16 +870,22 @@ static void mul_mat_f_switch_cols_per_block(
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, \
cudaStream_t stream, const mmf_ids_data * ids_data);
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
#if !defined(GGML_USE_MUSA)
#define DECL_MMF_CASE_EXTERN(ncols_dst) \
extern DECL_MMF_CASE_HELPER(float, ncols_dst) \
extern DECL_MMF_CASE_HELPER(half2, ncols_dst) \
extern DECL_MMF_CASE_HELPER(nv_bfloat162, ncols_dst)
extern DECL_MMF_CASE_HELPER(float, MMF_ROWS_PER_BLOCK, ncols_dst) \
extern DECL_MMF_CASE_HELPER(half2, MMF_ROWS_PER_BLOCK, ncols_dst) \
extern DECL_MMF_CASE_HELPER(nv_bfloat162, MMF_ROWS_PER_BLOCK, ncols_dst) \
extern DECL_MMF_CASE_HELPER(float, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst) \
extern DECL_MMF_CASE_HELPER(half2, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst) \
extern DECL_MMF_CASE_HELPER(nv_bfloat162, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst)
#define DECL_MMF_CASE(ncols_dst) \
DECL_MMF_CASE_HELPER(float, ncols_dst) \
DECL_MMF_CASE_HELPER(half2, ncols_dst) \
DECL_MMF_CASE_HELPER(nv_bfloat162, ncols_dst)
DECL_MMF_CASE_HELPER(float, MMF_ROWS_PER_BLOCK, ncols_dst) \
DECL_MMF_CASE_HELPER(half2, MMF_ROWS_PER_BLOCK, ncols_dst) \
DECL_MMF_CASE_HELPER(nv_bfloat162, MMF_ROWS_PER_BLOCK, ncols_dst) \
DECL_MMF_CASE_HELPER(float, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst) \
DECL_MMF_CASE_HELPER(half2, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst) \
DECL_MMF_CASE_HELPER(nv_bfloat162, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst)
DECL_MMF_CASE_EXTERN(1);
DECL_MMF_CASE_EXTERN(2);
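The epilogue rewrite in both kernels above replaces the single per-lane sum (which required rows_per_block == warp_size) with rows_per_block/warp_size partial sums per lane, strided by warp_size, so that 64-row CDNA blocks work with either warp width. A stripped-down CUDA illustration of the new indexing (demo function with invented parameter names, not the shipped code):

template <int rows_per_block, int warp_size>
__device__ void demo_combine_rows(const float * buf_iw, float * dst,
                                  int nwarps, int kiw, int j, int row0, int stride_col_dst) {
    static_assert(rows_per_block % warp_size == 0, "rows_per_block must be a multiple of warp_size");
    // Each lane owns rows_per_block/warp_size output rows, warp_size apart.
    float sum[rows_per_block / warp_size] = {0.0f};
    for (int i0 = 0; i0 < nwarps * rows_per_block; i0 += rows_per_block) {
        for (int i1 = 0; i1 < rows_per_block / warp_size; ++i1) {
            sum[i1] += buf_iw[j * kiw + i0 + i1 * warp_size + threadIdx.x];
        }
    }
    for (int i1 = 0; i1 < rows_per_block / warp_size; ++i1) {
        dst[j * stride_col_dst + row0 + i1 * warp_size + threadIdx.x] = sum[i1];
    }
}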

View file

@ -5,6 +5,13 @@
#include <cmath>
#include <initializer_list>
// Kernel config struct - passed by value to CUDA kernel
struct topk_moe_config {
bool use_sigmoid;
bool with_norm;
bool delayed_softmax;
};
// Warp-local softmax used for both the pre-top-k logits and the post-top-k delayed path.
template <int experts_per_thread, bool use_limit>
__device__ void softmax_warp_inplace(float (&vals)[experts_per_thread], const int limit, const int lane) {
@ -50,6 +57,16 @@ __device__ void softmax_warp_inplace(float (&vals)[experts_per_thread], const in
}
}
template <int experts_per_thread, bool use_limit>
__device__ void sigmoid_warp_inplace(float (&vals)[experts_per_thread], const int limit, const int lane) {
#pragma unroll
for (int i = 0; i < experts_per_thread; i++) {
const int idx = lane + i * WARP_SIZE;
const bool active = !use_limit || (idx < limit);
vals[i] = active ? 1.f / (1.f + expf(-vals[i])) : -INFINITY;
}
}
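sigmoid_warp_inplace mirrors the softmax helper but applies an elementwise logistic gate, as used by sigmoid-routed MoE models; out-of-range slots are forced to -INFINITY so they lose every later top-k comparison. A host-side scalar equivalent, for reference only:

#include <cmath>

// Scalar reference for the per-expert sigmoid gate: active entries become
// 1/(1+e^-x); entries past the limit are knocked out of the top-k.
static void demo_sigmoid_gate(float * vals, int n, int limit) {
    for (int i = 0; i < n; ++i) {
        vals[i] = (i < limit) ? 1.0f / (1.0f + std::exp(-vals[i])) : -INFINITY;
    }
}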
/*
This kernel does the following:
1. optionally softmax over the logits per token [n_experts, n_tokens]
@ -59,13 +76,16 @@ __device__ void softmax_warp_inplace(float (&vals)[experts_per_thread], const in
It is intended as a fusion of the softmax->top-k->get_rows pipeline for MoE models
*/
template <int n_experts, bool with_norm, bool delayed_softmax = false>
__launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * logits,
float * weights,
int32_t * ids,
const int n_rows,
const int n_expert_used,
const float clamp_val) {
template <int n_experts, bool has_bias>
__launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * logits,
float * weights,
int32_t * ids,
float * bias,
const int n_rows,
const int n_expert_used,
const float clamp_val,
const float scale_val,
const topk_moe_config config) {
const int row = blockIdx.x * blockDim.y + threadIdx.y;
if (row >= n_rows) {
return;
@ -79,14 +99,41 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
float wt[experts_per_thread];
// Initialize all slots to -INFINITY
#pragma unroll
for (int i = 0; i < experts_per_thread; i++) {
wt[i] = -INFINITY;
}
#pragma unroll
for (int i = 0; i < n_experts; i += WARP_SIZE) {
const int expert = i + threadIdx.x;
wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[expert] : -INFINITY;
}
if constexpr (!delayed_softmax) {
softmax_warp_inplace<experts_per_thread, false>(wt, n_experts, threadIdx.x);
if (!config.delayed_softmax) {
if (config.use_sigmoid) {
sigmoid_warp_inplace<experts_per_thread, false>(wt, n_experts, threadIdx.x);
} else {
softmax_warp_inplace<experts_per_thread, false>(wt, n_experts, threadIdx.x);
}
}
// selection_wt is only needed when bias is present (selection uses wt + bias)
// when no bias, we use wt directly for both selection and weight values
float selection_wt[has_bias ? experts_per_thread : 1];
if constexpr (has_bias) {
#pragma unroll
for (int i = 0; i < experts_per_thread; i++) {
selection_wt[i] = -INFINITY;
}
#pragma unroll
for (int i = 0; i < n_experts; i += WARP_SIZE) {
const int expert = i + threadIdx.x;
selection_wt[i / WARP_SIZE] =
(n_experts % WARP_SIZE == 0 || expert < n_experts) ? wt[i / WARP_SIZE] + bias[expert] : -INFINITY;
}
}
//at this point, each thread holds either a portion of the softmax distribution
@ -106,22 +153,56 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
float max_val = wt[0];
int max_expert = threadIdx.x;
#pragma unroll
for (int i = 1; i < experts_per_thread; i++) {
const int expert = threadIdx.x + i * WARP_SIZE;
if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && wt[i] > max_val) {
max_val = wt[i];
max_expert = expert;
}
}
if constexpr (has_bias) {
float max_val_s = selection_wt[0];
#pragma unroll
for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
const float val = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, WARP_SIZE);
const int expert = __shfl_xor_sync(0xFFFFFFFF, max_expert, mask, WARP_SIZE);
if (val > max_val || (val == max_val && expert < max_expert)) {
max_val = val;
max_expert = expert;
for (int i = 1; i < experts_per_thread; i++) {
const int expert = threadIdx.x + i * WARP_SIZE;
if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && selection_wt[i] > max_val_s) {
max_val = wt[i];
max_val_s = selection_wt[i];
max_expert = expert;
}
}
#pragma unroll
for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
const float val = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, WARP_SIZE);
const float val_s = __shfl_xor_sync(0xFFFFFFFF, max_val_s, mask, WARP_SIZE);
const int expert = __shfl_xor_sync(0xFFFFFFFF, max_expert, mask, WARP_SIZE);
if (val_s > max_val_s || (val_s == max_val_s && expert < max_expert)) {
max_val = val;
max_val_s = val_s;
max_expert = expert;
}
}
if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {
selection_wt[max_expert / WARP_SIZE] = -INFINITY;
}
} else {
#pragma unroll
for (int i = 1; i < experts_per_thread; i++) {
const int expert = threadIdx.x + i * WARP_SIZE;
if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && wt[i] > max_val) {
max_val = wt[i];
max_expert = expert;
}
}
#pragma unroll
for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
const float val = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, WARP_SIZE);
const int expert = __shfl_xor_sync(0xFFFFFFFF, max_expert, mask, WARP_SIZE);
if (val > max_val || (val == max_val && expert < max_expert)) {
max_val = val;
max_expert = expert;
}
}
if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {
wt[max_expert / WARP_SIZE] = -INFINITY;
}
}
@ -130,16 +211,14 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
}
if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {
wt[max_expert / WARP_SIZE] = -INFINITY;
ids[k] = max_expert;
if constexpr (with_norm) {
if (config.with_norm) {
wt_sum += max_val;
}
}
}
if constexpr (with_norm) {
if (config.with_norm) {
wt_sum = warp_reduce_sum(wt_sum);
wt_sum = max(wt_sum, clamp_val);
const float inv_sum = 1.0f / wt_sum;
@ -149,7 +228,7 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
}
}
if constexpr (delayed_softmax) {
if (config.delayed_softmax) {
softmax_warp_inplace<experts_per_thread, true>(output_weights, n_expert_used, threadIdx.x);
}
@ -157,25 +236,25 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
for (int i = 0; i < experts_per_thread; i++) {
const int idx = i * WARP_SIZE + threadIdx.x;
if (idx < n_expert_used) {
weights[idx] = output_weights[i];
weights[idx] = output_weights[i] * scale_val;
}
}
if (!with_norm) {
GGML_UNUSED(clamp_val);
}
}
template <bool with_norm, bool delayed_softmax = false>
template<bool has_bias>
static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
const float * logits,
float * weights,
int32_t * ids,
float * bias,
const int n_rows,
const int n_expert,
const int n_expert_used,
const float clamp_val) {
static_assert(!(with_norm && delayed_softmax), "delayed softmax is not supported with weight normalization");
const float clamp_val,
const float scale_val,
const topk_moe_config config) {
GGML_ASSERT(!(config.with_norm && config.delayed_softmax) &&
"delayed softmax is not supported with weight normalization");
const int rows_per_block = 4;
dim3 grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1);
dim3 block_dims(WARP_SIZE, rows_per_block, 1);
@ -183,44 +262,48 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
switch (n_expert) {
case 1:
topk_moe_cuda<1, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
topk_moe_cuda<1, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
clamp_val, scale_val, config);
break;
case 2:
topk_moe_cuda<2, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
topk_moe_cuda<2, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
clamp_val, scale_val, config);
break;
case 4:
topk_moe_cuda<4, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
topk_moe_cuda<4, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
clamp_val, scale_val, config);
break;
case 8:
topk_moe_cuda<8, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
topk_moe_cuda<8, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
clamp_val, scale_val, config);
break;
case 16:
topk_moe_cuda<16, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
topk_moe_cuda<16, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
clamp_val, scale_val, config);
break;
case 32:
topk_moe_cuda<32, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
topk_moe_cuda<32, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
clamp_val, scale_val, config);
break;
case 64:
topk_moe_cuda<64, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
topk_moe_cuda<64, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
clamp_val, scale_val, config);
break;
case 128:
topk_moe_cuda<128, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
topk_moe_cuda<128, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
clamp_val, scale_val, config);
break;
case 256:
topk_moe_cuda<256, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
topk_moe_cuda<256, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
clamp_val, scale_val, config);
break;
case 512:
topk_moe_cuda<512, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
topk_moe_cuda<512, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
clamp_val, scale_val, config);
break;
case 576:
topk_moe_cuda<576, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
clamp_val, scale_val, config);
break;
default:
GGML_ASSERT(false && "fatal error");
@ -228,13 +311,14 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
}
}
void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
const ggml_tensor * logits,
ggml_tensor * weights,
ggml_tensor * ids,
const bool with_norm,
const bool delayed_softmax,
ggml_tensor * clamp) {
void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
const ggml_tensor * logits,
ggml_tensor * weights,
ggml_tensor * ids,
const ggml_tensor * clamp,
const ggml_tensor * scale,
const ggml_tensor * bias,
const ggml_cuda_topk_moe_args & args) {
GGML_ASSERT(logits->type == GGML_TYPE_F32);
GGML_ASSERT(weights->type == GGML_TYPE_F32);
GGML_ASSERT(ids->type == GGML_TYPE_I32);
@ -245,107 +329,75 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
const float * logits_d = (const float *) logits->data;
float * weights_d = (float *) weights->data;
int32_t * ids_d = (int32_t *) ids->data;
float * bias_d = bias ? (float *) bias->data : nullptr;
float scale_val = scale ? ggml_get_op_params_f32(scale, 0) : 1.0f;
GGML_ASSERT(ids->nb[1] / ggml_type_size(ids->type) == (size_t) n_experts);
const int n_expert_used = weights->ne[1];
const bool with_norm = clamp != nullptr;
float clamp_val = -INFINITY;
if (with_norm) {
if (clamp) {
clamp_val = ggml_get_op_params_f32(clamp, 0);
}
launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used, clamp_val);
if (clamp) {
clamp_val = ggml_get_op_params_f32(clamp, 0);
}
topk_moe_config config;
config.use_sigmoid = args.sigmoid;
config.with_norm = with_norm;
config.delayed_softmax = args.delayed_softmax;
if (bias) {
launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, bias_d, n_rows, n_experts, n_expert_used, clamp_val,
scale_val, config);
} else {
GGML_ASSERT(clamp == nullptr);
if (delayed_softmax) {
launch_topk_moe_cuda<false, true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used,
clamp_val);
} else {
launch_topk_moe_cuda<false, false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used,
clamp_val);
}
launch_topk_moe_cuda<false>(ctx, logits_d, weights_d, ids_d, bias_d, n_rows, n_experts, n_expert_used, clamp_val,
scale_val, config);
}
}
bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax,
bool ggml_cuda_should_use_topk_moe(const ggml_tensor * gating_op,
const ggml_tensor * weights,
const ggml_tensor * get_rows,
const ggml_tensor * argsort,
const ggml_tensor * clamp,
int n_expert) {
ggml_tensor * probs = get_rows->src[0];
if (probs->op != GGML_OP_RESHAPE) {
return false;
}
probs = probs->src[0];
ggml_tensor * selection_probs = argsort->src[0];
if (probs != selection_probs) {
const ggml_tensor * logits,
const ggml_tensor * ids) {
const int n_expert = ids->nb[1] / ids->nb[0];
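// n_expert must be a power of two no larger than 512, with 576 additionally special-cased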
if (((n_expert & (n_expert - 1)) != 0 || n_expert > 512) && n_expert != 576) {
return false;
}
float scale = 1.0f;
float max_bias = 0.0f;
memcpy(&scale, (const float *) softmax->op_params + 0, sizeof(float));
memcpy(&max_bias, (const float *) softmax->op_params + 1, sizeof(float));
if (!ggml_is_contiguous(softmax->src[0]) || !ggml_is_contiguous(weights)) {
if (!ggml_is_contiguous(weights) || !ggml_is_contiguous(logits)) {
return false;
}
if (scale != 1.0f || max_bias != 0.0f) {
return false;
}
if (gating_op->op == GGML_OP_SOFT_MAX) {
const ggml_tensor * softmax = gating_op;
float scale = 1.0f;
float max_bias = 0.0f;
// don't fuse when masks or sinks are present
if (softmax->src[1] || softmax->src[2]) {
return false;
}
memcpy(&scale, (const float *) softmax->op_params + 0, sizeof(float));
memcpy(&max_bias, (const float *) softmax->op_params + 1, sizeof(float));
// n_expert must be a power of 2
if ((n_expert & (n_expert - 1)) != 0 || n_expert > 512) {
return false;
}
if (clamp) {
if (clamp->op != GGML_OP_CLAMP) {
if (!ggml_is_contiguous(softmax->src[0])) {
return false;
}
float max_val = ggml_get_op_params_f32(clamp, 1);
if (max_val != INFINITY) {
if (scale != 1.0f || max_bias != 0.0f) {
return false;
}
// don't fuse when masks or sinks are present
if (softmax->src[1] || softmax->src[2]) {
return false;
}
} else if (gating_op->op == GGML_OP_UNARY) {
ggml_unary_op op = ggml_get_unary_op(gating_op);
if (op != GGML_UNARY_OP_SIGMOID) {
return false;
}
}
return true;
}
std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool norm, bool delayed_softmax) {
static std::initializer_list<enum ggml_op> norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
GGML_OP_SUM_ROWS, GGML_OP_CLAMP, GGML_OP_DIV,
GGML_OP_RESHAPE };
static std::initializer_list<enum ggml_op> no_norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
GGML_OP_VIEW, GGML_OP_GET_ROWS };
static std::initializer_list<enum ggml_op> delayed_softmax_ops = { GGML_OP_ARGSORT, GGML_OP_VIEW,
GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };
GGML_ASSERT(!norm || !delayed_softmax);
if (delayed_softmax) {
return delayed_softmax_ops;
}
if (norm) {
return norm_ops;
}
return no_norm_ops;
}
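One subtlety of the has_bias path above deserves a plain statement: the bias influences only which experts win selection, while the reported weight (and any later normalization) uses the unbiased gate value. A scalar sketch of that rule (invented demo helper; assumes k <= gate.size() and glosses over the kernel's lowest-index tie-breaking and warp-parallel argmax):

#include <algorithm>
#include <utility>
#include <vector>

// Rank experts by (gate + bias), but report the unbiased gate as the weight,
// matching the has_bias branch of the fused kernel.
static std::vector<std::pair<int, float>> demo_biased_topk(const std::vector<float> & gate,
                                                           const std::vector<float> & bias,
                                                           int k) {
    std::vector<int> idx(gate.size());
    for (int i = 0; i < (int) idx.size(); ++i) {
        idx[i] = i;
    }
    std::partial_sort(idx.begin(), idx.begin() + k, idx.end(),
                      [&](int a, int b) { return gate[a] + bias[a] > gate[b] + bias[b]; });
    std::vector<std::pair<int, float>> out;
    for (int i = 0; i < k; ++i) {
        out.emplace_back(idx[i], gate[idx[i]]);
    }
    return out;
}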

View file

@ -3,19 +3,25 @@
#include <initializer_list>
void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
const ggml_tensor * logits,
ggml_tensor * weights,
ggml_tensor * ids,
const bool with_norm,
const bool delayed_softmax = false,
ggml_tensor * weight_clamp = nullptr);
struct ggml_cuda_topk_moe_args {
bool sigmoid{};
bool softmax{};
bool delayed_softmax{};
bool prob_bias{};
bool norm{};
bool scale{};
};
bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax,
void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
const ggml_tensor * logits,
ggml_tensor * weights,
ggml_tensor * ids,
const ggml_tensor * clamp,
const ggml_tensor * scale,
const ggml_tensor * bias,
const ggml_cuda_topk_moe_args & args);
bool ggml_cuda_should_use_topk_moe(const ggml_tensor * gating_op,
const ggml_tensor * weights,
const ggml_tensor * get_rows,
const ggml_tensor * argsort,
const ggml_tensor * clamp,
int n_expert);
std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool with_norm, bool delayed_softmax = false);
const ggml_tensor * logits,
const ggml_tensor * ids);
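For orientation, ggml_cuda_topk_moe_args replaces the old pair of template booleans with runtime flags filled in by the detection pass. A hedged caller-side sketch, assuming this header is included; the flag values shown are one plausible combination, not a fixed recipe:

static ggml_cuda_topk_moe_args demo_make_args() {
    ggml_cuda_topk_moe_args args = {};
    args.sigmoid         = true;   // gating op was GGML_UNARY_OP_SIGMOID
    args.softmax         = false;  // mutually exclusive with sigmoid gating
    args.prob_bias       = true;   // a per-expert bias ADD precedes the argsort
    args.norm            = true;   // weights are renormalized over the selected experts
    args.scale           = false;  // no trailing GGML_OP_SCALE detected
    args.delayed_softmax = false;  // gating happened before selection
    return args;
}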

View file

@ -3178,17 +3178,31 @@ static void ggml_vk_load_shaders(vk_device& device) {
// For scalar, use 128 (arbitrary)
// The same D_split value is used for both HSK and HSV, so just base it on the union of the LSBs.
const uint32_t D = (hsk|hsv);
uint32_t wg_size = (path == FA_SCALAR || path == FA_COOPMAT1)
? scalar_flash_attention_workgroup_size
: ((small_rows && (D % 32) == 0) ? 256 : 128);
auto rows_cols = fa_rows_cols(path, hsk, hsv, clamp, type, small_rows, small_cache);
uint32_t wg_size;
switch (path) {
case FA_COOPMAT2:
wg_size = ((small_rows && (D % 32) == 0) ? 256 : 128);
break;
case FA_COOPMAT1:
wg_size = (rows_cols[1] / 16) * device->subgroup_size; // enough subgroups for Bc/MatBc
break;
default:
wg_size = scalar_flash_attention_workgroup_size;
break;
}
// D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it.
// D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader.
const uint32_t D_lsb = D ^ (D & (D-1));
uint32_t D_split = std::min(std::min(device->subgroup_size, 8u), D_lsb / 4);
return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split};
// Nvidia prefers using shared memory to load large tiles of K,
// while AMD prefers loading K directly from global memory.
const uint32_t k_load_shmem = device->vendor_id == VK_VENDOR_ID_NVIDIA ? 1 : 0;
return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split, device->subgroup_size, k_load_shmem};
};
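The D_lsb expression above isolates the lowest set bit of D (for example, D = 192 = 0b11000000 gives 64), which, divided by 4, bounds D_split because of the shader's 4-wide vectorization. An equivalent standalone formulation:

#include <cstdint>

// D ^ (D & (D-1)) clears everything but the lowest set bit; for nonzero x it
// is the classic two's-complement identity x & -x.
static uint32_t demo_lowest_set_bit(uint32_t x) {
    return x & (0u - x);
}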
#define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \
@ -3203,15 +3217,15 @@ static void ggml_vk_load_shaders(vk_device& device) {
if (path == FAPATH) { \
if (aligned) { \
if (f32acc) { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
} else { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
} \
} else { \
if (f32acc) { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
} else { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
} \
} \
} \
@ -8382,41 +8396,49 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con
const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf;
const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", total_size=" << total_size << ", supported=" << supported);
VK_LOG_DEBUG("ggml_vk_flash_attn_scalar_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", total_size=" << total_size << ", supported=" << supported);
return supported;
}
static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, bool f32acc) {
static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type kv_type) {
// Needs to be kept up to date on shader changes
GGML_UNUSED(hsv);
const uint32_t wg_size = scalar_flash_attention_workgroup_size;
const uint32_t Br = coopmat1_flash_attention_num_large_rows;
const uint32_t Bc = scalar_flash_attention_Bc;
const auto rows_cols = fa_rows_cols(FA_COOPMAT1, hsk, hsv, 0, kv_type, false, false);
const uint32_t Br = rows_cols[0];
const uint32_t Bc = rows_cols[1];
const uint32_t MatBr = 16, MatBc = 16;
const uint32_t row_split = Bc / MatBc;
const uint32_t hsk_pad = ROUNDUP_POW2(hsk, 16);
const uint32_t acctype = f32acc ? 4 : 2;
const uint32_t f16vec4 = 8;
const uint32_t tmpsh = wg_size * sizeof(float);
const uint32_t tmpshv4 = wg_size * 4 * acctype;
const uint32_t tmpsh = (Bc / MatBc) * sizeof(float);
const uint32_t qstride = hsk_pad / 4 + 2;
const uint32_t Qf = Br * qstride * f16vec4;
const uint32_t psh_stride = Br / 4 + 2;
const uint32_t Psh = Bc * psh_stride * f16vec4;
const uint32_t sfshstride = (hsk <= 128) ? (Br + 8) : Br;
const uint32_t sfsh = Bc * sfshstride * acctype;
const uint32_t kshstride = hsk_pad / 4 + 2;
const uint32_t ksh = Bc * kshstride * f16vec4;
const bool k_load_shmem = device->vendor_id == VK_VENDOR_ID_NVIDIA;
const uint32_t kshstride = (k_load_shmem ? hsk_pad : MatBr) / 4 + 2;
const uint32_t vsh_stride = MatBc / 4 * row_split;
const uint32_t ksh = ((kshstride >= vsh_stride) ? (Bc * kshstride) : (Bc * vsh_stride)) * f16vec4;
const uint32_t slope = Br * sizeof(float);
const uint32_t slope = Br * acctype;
const uint32_t total_size = tmpsh + tmpshv4 + Qf + sfsh + ksh + slope;
const uint32_t total_size = tmpsh + Qf + Psh + sfsh + ksh + slope;
const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", f32acc=" << f32acc << ", total_size=" << total_size << ", supported=" << supported);
VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", f32acc=" << f32acc << ", kv_type=" << kv_type << ", total_size=" << total_size << ", supported=" << supported);
return supported;
}
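To make the budget concrete, a worked example under assumed tile sizes (Br = 32, Bc = 64 from fa_rows_cols, hsk = hsv = 128, f16 accumulators, k_load_shmem = true; the real values depend on the device and path):

  tmpsh = (64/16) * 4                  =    16 B
  Qf    = 32 * (128/4 + 2) * 8         =  8704 B
  Psh   = 64 * (32/4 + 2) * 8          =  5120 B
  sfsh  = 64 * (32 + 8) * 2            =  5120 B
  ksh   = 64 * max(128/4 + 2, 16) * 8  = 17408 B
  slope = 32 * 2                       =    64 B
  total                                = 36432 B

which fits comfortably under a typical 48 KiB maxComputeSharedMemorySize.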
@ -8480,7 +8502,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
const bool coopmat_shape_supported = (dst->op_params[3] == GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f32acc) ||
(dst->op_params[3] != GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f16acc);
const bool coopmat_shmem_supported = ggml_vk_flash_attn_coopmat_shmem_support(ctx->device, HSK, HSV, dst->op_params[3] == GGML_PREC_F32);
const bool coopmat_shmem_supported = ggml_vk_flash_attn_coopmat_shmem_support(ctx->device, HSK, HSV, dst->op_params[3] == GGML_PREC_F32, k->type);
if (!coopmat_shape_supported || !coopmat_shmem_supported) {
path = FA_SCALAR;

View file

@ -8,6 +8,8 @@ layout (constant_id = 3) const uint32_t HSK = 32;
layout (constant_id = 4) const uint32_t HSV = 32;
layout (constant_id = 5) const uint32_t Clamp = 0;
layout (constant_id = 6) const uint32_t D_split = 16;
layout (constant_id = 7) const uint32_t SubGroupSize = 32;
layout (constant_id = 8) const uint32_t K_LOAD_SHMEM = 0;
// Round up head sizes to a multiple of 16, for coopmat1/coopmat2 paths
const uint32_t HSK_pad = (HSK + 15) & ~15;
@ -74,6 +76,10 @@ layout (binding = 1) readonly buffer K_PACKED16 {A_TYPE_PACKED16 k_data_packed16
layout (binding = 2) readonly buffer V_PACKED16 {A_TYPE_PACKED16 v_data_packed16[];} v_packed;
#endif
#ifndef BLOCK_SIZE
#define BLOCK_SIZE 1
#endif
#if defined(DATA_A_F32)
#undef BLOCK_SIZE
#define BLOCK_SIZE 4

View file

@ -7,6 +7,7 @@
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#extension GL_KHR_shader_subgroup_basic : enable
#extension GL_KHR_shader_subgroup_arithmetic : enable
#extension GL_KHR_shader_subgroup_vote : enable
#extension GL_KHR_memory_scope_semantics : enable
#extension GL_KHR_cooperative_matrix : enable
@ -14,12 +15,13 @@
#include "types.glsl"
#include "flash_attn_base.glsl"
const uint32_t HSK_per_thread = HSK / D_split;
const uint32_t HSV_per_thread = HSV / D_split;
// These need to be supported N,M values for a MatBc x MatBr x 16 coopmatmuladd
const uint32_t MatBr = 16;
const uint32_t MatBc = 16;
const uint32_t row_split = 4;
const uint32_t row_split = Bc / MatBc;
const uint32_t rows_per_thread = Br / row_split;
const uint32_t cols_per_iter = gl_WorkGroupSize.x / D_split / row_split;
const uint32_t cols_per_iter = gl_WorkGroupSize.x / row_split;
const uint32_t cols_per_thread = Bc / cols_per_iter;
@ -40,24 +42,24 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TY
return elem;
}
// These need to be supported N,M values for a MatBc x MatBr x 16 coopmatmuladd
const uint32_t MatBr = 16;
const uint32_t MatBc = 16;
shared FLOAT_TYPE tmpsh[gl_WorkGroupSize.x];
shared ACC_TYPEV4 tmpshv4[gl_WorkGroupSize.x];
shared float tmpsh[row_split];
const uint32_t qstride = HSK_pad / 4 + 2; // in units of f16vec4
shared f16vec4 Qf[Br * qstride];
const uint psh_stride = Br / 4 + 2;
shared f16vec4 Psh[Bc * psh_stride];
// Avoid padding for hsk==256 to make it fit in 48KB shmem.
const uint32_t sfshstride = (HSK <= 128) ? (Br + 8) : Br;
shared ACC_TYPE sfsh[Bc * sfshstride];
const uint32_t sfshstride = (HSK <= 128) ? (Br / 4 + 2) : Br / 4;
shared ACC_TYPEV4 sfsh[Bc * sfshstride];
const uint32_t kshstride = HSK_pad / 4 + 2; // in units of f16vec4
shared f16vec4 ksh[Bc * kshstride];
const uint32_t kshstride = (K_LOAD_SHMEM != 0 ? HSK_pad : MatBr) / 4 + 2; // in units of f16vec4
const uint v_cols = MatBc / 4 * row_split; // total cols, 4 vec4s per MatBc * number of subgroups
const uint vsh_stride = v_cols;
shared f16vec4 ksh[(kshstride >= vsh_stride) ? (Bc * kshstride) : (Bc * vsh_stride)];
shared float slope[Br];
shared ACC_TYPE slope[Br];
void main() {
#ifdef NEEDS_INIT_IQ_SHMEM
@ -69,9 +71,9 @@ void main() {
const uint32_t tid = gl_LocalInvocationIndex;
const uint32_t threads_per_rowgroup = gl_WorkGroupSize.x / row_split;
const uint32_t d_per_thread = (HSV/4 + threads_per_rowgroup - 1) / threads_per_rowgroup;
const uint32_t row_tid = gl_LocalInvocationIndex / threads_per_rowgroup;
const uint32_t d_tid = gl_LocalInvocationIndex % D_split;
const uint32_t col_tid = (gl_LocalInvocationIndex % threads_per_rowgroup) / D_split;
const uint32_t col_tid = gl_LocalInvocationIndex % threads_per_rowgroup;
#define tile_row(r) (row_tid * rows_per_thread + (r))
@ -102,9 +104,9 @@ void main() {
}
barrier();
ACC_TYPEV4 Of[rows_per_thread][HSV_per_thread / 4];
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
ACC_TYPEV4 Of[rows_per_thread][d_per_thread];
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
[[unroll]] for (uint32_t d = 0; d < d_per_thread; ++d) {
Of[r][d] = ACC_TYPEV4(0.0);
}
}
@ -125,13 +127,11 @@ void main() {
uint r = tid;
slope[r] = perElemOpComputeSlope(r, col_tid, ACC_TYPE(0), iq2);
}
barrier();
} else {
if (tid < Br) {
uint r = tid;
slope[r] = 1.0;
slope[r] = ACC_TYPE(1.0);
}
barrier();
}
#if BLOCK_SIZE > 1
@ -149,19 +149,45 @@ void main() {
[[dont_unroll]]
for (uint32_t j = start_j; j < end_j; ++j) {
float mask_cache[Bc * Br / WorkGroupSize];
f16vec4 mask_cache[Bc * Br / 4 / WorkGroupSize];
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
float max_mask = NEG_FLT_MAX_OVER_2;
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
uint32_t c = (idx + tid) % Bc;
uint32_t r = (idx + tid) / Bc;
if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
float m = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br / 4; idx += gl_WorkGroupSize.x) {
uint32_t c = (idx + tid) / (Br / 4);
uint32_t r = (idx + tid) % (Br / 4);
if (idx + tid < Bc * Br / 4 || idx + gl_WorkGroupSize.x <= Bc * Br / 4) {
if ((!KV_bounds_check || j * Bc + c < KV)) {
f16vec4 m;
if (!nem1_bounds_check || i * Br + r * 4 + 3 < p.nem1) {
m = f16vec4(data_m[m_offset + (i * Br + r * 4 ) * m_stride + (j * Bc + c)],
data_m[m_offset + (i * Br + r * 4 + 1) * m_stride + (j * Bc + c)],
data_m[m_offset + (i * Br + r * 4 + 2) * m_stride + (j * Bc + c)],
data_m[m_offset + (i * Br + r * 4 + 3) * m_stride + (j * Bc + c)]);
max_mask = max(max(max(max(max_mask, float(m[0])), float(m[1])), float(m[2])), float(m[3]));
} else if (i * Br + r * 4 + 2 < p.nem1) {
m = f16vec4(data_m[m_offset + (i * Br + r * 4 ) * m_stride + (j * Bc + c)],
data_m[m_offset + (i * Br + r * 4 + 1) * m_stride + (j * Bc + c)],
data_m[m_offset + (i * Br + r * 4 + 2) * m_stride + (j * Bc + c)],
0.0);
max_mask = max(max(max(max_mask, float(m[0])), float(m[1])), float(m[2]));
} else if (i * Br + r * 4 + 1 < p.nem1) {
m = f16vec4(data_m[m_offset + (i * Br + r * 4 ) * m_stride + (j * Bc + c)],
data_m[m_offset + (i * Br + r * 4 + 1) * m_stride + (j * Bc + c)],
0.0,
0.0);
max_mask = max(max(max_mask, float(m[0])), float(m[1]));
} else if (i * Br + r * 4 < p.nem1) {
m = f16vec4(data_m[m_offset + (i * Br + r * 4 ) * m_stride + (j * Bc + c)],
0.0,
0.0,
0.0);
max_mask = max(max_mask, float(m[0]));
} else {
m = f16vec4(0.0);
}
mask_cache[idx / WorkGroupSize] = m;
max_mask = max(max_mask, m);
}
}
}
@ -180,26 +206,28 @@ void main() {
}
}
[[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) {
uint32_t d = (idx + tid) % (HSK / 4);
uint32_t c = (idx + tid) / (HSK / 4);
if (c < Bc && d < HSK / 4) {
f16vec4 K_Tf = f16vec4(0);
if (!KV_bounds_check || j * Bc + c < KV) {
if (K_LOAD_SHMEM != 0) {
[[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) {
uint32_t d = (idx + tid) % (HSK / 4);
uint32_t c = (idx + tid) / (HSK / 4);
if (c < Bc && d < HSK / 4) {
f16vec4 K_Tf = f16vec4(0);
if (!KV_bounds_check || j * Bc + c < KV) {
#if BLOCK_SIZE > 1
uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d;
uint ib = coord / BLOCK_SIZE;
uint iqs = (coord % BLOCK_SIZE);
K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K));
uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d;
uint ib = coord / BLOCK_SIZE;
uint iqs = (coord % BLOCK_SIZE);
K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K));
#else
K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]);
K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]);
#endif
}
}
ksh[c * kshstride + d] = K_Tf;
ksh[c * kshstride + d] = K_Tf;
}
}
barrier();
}
barrier();
// K * Q^T -> S^T: Bc x HSK_pad * HSK_pad x Br -> Bc x Br
// Bc split across workgroup (four subgroups), loop over HSK in chunks of 16: 16 x 16 * 16 x 16 -> 16 x 16
@ -208,11 +236,55 @@ void main() {
coopmat<float16_t, gl_ScopeSubgroup, MatBc, 16, gl_MatrixUseA> KMat;
coopmat<float16_t, gl_ScopeSubgroup, 16, MatBr, gl_MatrixUseB> QMat;
for (uint32_t d = 0; d < HSK_pad / 16; ++d) {
coopMatLoad(QMat, Qf, d * 16 / 4, qstride, gl_CooperativeMatrixLayoutColumnMajor);
[[unroll]] for (uint32_t d = 0; d < HSK_pad / 16; ++d) {
if (K_LOAD_SHMEM == 0) {
#if BLOCK_SIZE == 1
if (KV_bounds_check || d * 16 + 16 > HSK) {
#endif
barrier();
[[unroll]] for (uint32_t idx = 0; idx < Bc * MatBr / 4; idx += gl_WorkGroupSize.x) {
uint32_t col_vec = (idx + tid) % (MatBr / 4);
uint32_t row = (idx + tid) / (MatBr / 4);
if (idx + tid < Bc * MatBr / 4) {
f16vec4 K_Tf = f16vec4(0);
if ((!KV_bounds_check || j * Bc + row < KV) && (HSK == HSK_pad || d * 16 + col_vec * 4 < HSK)) {
#if BLOCK_SIZE > 1
uint coord = (j * Bc + row) * k_stride * BLOCK_SIZE + d * 16 + col_vec * 4;
uint ib = coord / BLOCK_SIZE;
uint iqs = (coord % BLOCK_SIZE);
K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K));
#else
K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + row) * k_stride / 4 + d * 16 / 4 + col_vec]);
#endif
}
uint coord = (gl_SubgroupID * MatBc) * kshstride + d * 16 / 4;
coopMatLoad(KMat, ksh, coord, kshstride, gl_CooperativeMatrixLayoutRowMajor);
ksh[row * kshstride + col_vec] = K_Tf;
}
}
barrier();
#if BLOCK_SIZE == 1
}
#endif
#if BLOCK_SIZE == 1
if (KV_bounds_check || d * 16 + 16 > HSK)
#endif
{
uint coord = (gl_SubgroupID * MatBc) * kshstride;
coopMatLoad(KMat, ksh, coord, kshstride, gl_CooperativeMatrixLayoutRowMajor);
}
#if BLOCK_SIZE == 1
else {
const uint coord = k_offset / 4 + (j * Bc + gl_SubgroupID * MatBc) * k_stride / 4 + d * 16 / 4;
coopMatLoad(KMat, data_kv4, coord, k_stride / 4, gl_CooperativeMatrixLayoutRowMajor);
}
#endif
} else {
uint coord = (gl_SubgroupID * MatBc) * kshstride + d * 16 / 4;
coopMatLoad(KMat, ksh, coord, kshstride, gl_CooperativeMatrixLayoutRowMajor);
}
coopMatLoad(QMat, Qf, d * 16 / 4, qstride, gl_CooperativeMatrixLayoutColumnMajor);
SfMat = coopMatMulAdd(KMat, QMat, SfMat);
}
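A quick dimension check on the loop above, under the same assumed tile sizes (Bc = 64, hsk = 128, so HSK_pad = 128): row_split = Bc/MatBc = 4 subgroups each load one MatBc x 16 = 16 x 16 strip of K (rows gl_SubgroupID*16 through gl_SubgroupID*16 + 15), multiply it against a 16 x MatBr slice of Q, and the d loop runs HSK_pad/16 = 8 iterations, accumulating the partial K*Q^T products into each subgroup's MatBc x MatBr accumulator.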
@ -222,26 +294,26 @@ void main() {
barrier();
if (p.logit_softcap != 0.0f) {
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
uint32_t c = (idx + tid) / Br;
uint32_t r = (idx + tid) % Br;
if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
sfsh[c * sfshstride + r] = ACC_TYPE(p.logit_softcap * tanh(sfsh[c * sfshstride + r]));
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br / 4; idx += gl_WorkGroupSize.x) {
uint32_t c = (idx + tid) / (Br / 4);
uint32_t r = (idx + tid) % (Br / 4);
if (idx + tid < Bc * Br / 4 || idx + gl_WorkGroupSize.x <= Bc * Br / 4) {
sfsh[c * sfshstride + r] = ACC_TYPEV4(p.logit_softcap * tanh(sfsh[c * sfshstride + r]));
}
}
barrier();
}
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
uint32_t c = (idx + tid) % Bc;
uint32_t r = (idx + tid) / Bc;
if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
float f = mask_cache[idx / WorkGroupSize];
sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * f);
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br / 4; idx += gl_WorkGroupSize.x) {
uint32_t c = (idx + tid) / (Br / 4);
uint32_t r = (idx + tid) % (Br / 4);
if (idx + tid < Bc * Br / 4 || idx + gl_WorkGroupSize.x <= Bc * Br / 4) {
if (!KV_bounds_check || j * Bc + c < KV) {
// Mask nem1 bounds check is handled when loading masks
ACC_TYPEV4 masks = ACC_TYPEV4(mask_cache[idx / WorkGroupSize]);
ACC_TYPEV4 slopes = ACC_TYPEV4(slope[r * 4], slope[r * 4 + 1], slope[r * 4 + 2], slope[r * 4 + 3]);
sfsh[c * sfshstride + r] += slopes * masks;
}
}
}
@ -250,51 +322,145 @@ void main() {
float eMf[rows_per_thread];
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
const uint r_vec = tile_row(r) / 4;
const uint r_comp = tile_row(r) % 4;
float rowmaxf = NEG_FLT_MAX_OVER_2;
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
continue;
}
rowmaxf = max(rowmaxf, float(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride]));
rowmaxf = max(rowmaxf, float(sfsh[r_vec + (c * cols_per_iter + col_tid) * sfshstride][r_comp]));
}
float Moldf = Mf[r];
// Compute max across the row
rowmaxf = subgroupMax(rowmaxf);
// M = max(rowmax, Mold)
// P = e^(S - M)
// eM = e^(Mold - M)
Mf[r] = max(rowmaxf, Moldf);
eMf[r] = exp(Moldf - Mf[r]);
}
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Of[r][d] = ACC_TYPE(eMf[r]) * Of[r][d];
}
}
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Lf[r] = eMf[r]*Lf[r];
}
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
continue;
}
float Pf[rows_per_thread];
[[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
const uint d_local = d0 / threads_per_rowgroup;
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Pf[r] = exp(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride] - Mf[r]);
Lf[r] += Pf[r];
Of[r][d_local] = ACC_TYPE(eMf[r]) * Of[r][d_local];
}
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
#if BLOCK_SIZE > 1
uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
uint ib = coord / BLOCK_SIZE;
uint iqs = (coord % BLOCK_SIZE);
vec4 Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
#else
vec4 Vf = vec4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]);
}
// Calculate and store Pf in Psh
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
const uint col = c * cols_per_iter + col_tid;
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; r += 4) {
const uint row = tile_row(r);
if (KV_bounds_check && j * Bc + col >= KV) {
Psh[col * psh_stride + row / 4] = f16vec4(0.0f);
} else {
const vec4 mfvec = vec4(Mf[r], Mf[r + 1], Mf[r + 2], Mf[r + 3]);
const f16vec4 Pf = f16vec4(exp(vec4(sfsh[row / 4 + col * sfshstride]) - mfvec));
[[unroll]] for (uint32_t vec_idx = 0; vec_idx < 4; ++vec_idx) {
Lf[r + vec_idx] += Pf[vec_idx];
}
Psh[col * psh_stride + row / 4] = Pf;
}
}
}
const uint num_hsv_tiles = (HSV + MatBc * row_split - 1) / (MatBc * row_split); // round up
// Each subgroup handles HSV/4 columns
[[unroll]] for (uint32_t hsv_tile = 0; hsv_tile < num_hsv_tiles; ++hsv_tile) {
const uint hsv_offset = (hsv_tile * row_split + gl_SubgroupID) * 16;
SfMat = coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator>(0);
// Preload V tiles for [Bc, 16 * num subgroups]
const uint v_rows = Bc;
const uint v_total = v_rows * v_cols;
const uint v_loads_per_thread = v_total / gl_WorkGroupSize.x;
#if BLOCK_SIZE == 1
// For f16, only preload if not aligned
if (KV_bounds_check) {
#endif
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Of[r][d] += ACC_TYPE(Pf[r]) * ACC_TYPEV4(Vf);
[[unroll]] for (uint32_t i = 0; i < v_loads_per_thread; ++i) {
const uint idx = i * gl_WorkGroupSize.x + tid;
const uint row = idx / v_cols;
const uint col = idx % v_cols;
const uint v_row = j * Bc + row;
const uint v_col = hsv_tile * MatBc * row_split + col * 4;
const uint coord = v_row * v_stride * BLOCK_SIZE + v_col;
const uint ib = coord / BLOCK_SIZE;
const uint iqs = coord % BLOCK_SIZE;
if (!KV_bounds_check || (v_row < KV && v_col < HSV)) {
#if BLOCK_SIZE > 1
ksh[row * vsh_stride + col] = f16vec4(dequantize4(ib, iqs, v_offset, BINDING_IDX_V));
#else
ksh[row * vsh_stride + col] = data_vv4[(v_offset + v_row * v_stride + v_col) / 4];
#endif
} else {
ksh[row * vsh_stride + col] = f16vec4(0.0f);
}
}
#if BLOCK_SIZE == 1
}
#endif
barrier();
[[unroll]] for (uint32_t bc_chunk = 0; bc_chunk < Bc / MatBc; ++bc_chunk) {
coopMatLoad(KMat, Psh, bc_chunk * MatBc * psh_stride, psh_stride, gl_CooperativeMatrixLayoutColumnMajor);
#if BLOCK_SIZE == 1
if (!KV_bounds_check) {
// F16 values can be loaded directly from global memory
const uint v_tile_row = j * Bc + bc_chunk * MatBc;
const uint v_tile_offset = v_offset / 4 + v_tile_row * v_stride / 4 + hsv_offset / 4;
coopMatLoad(QMat, data_vv4, v_tile_offset, v_stride / 4, gl_CooperativeMatrixLayoutRowMajor);
} else
#endif
{
const uint v_tile_offset = bc_chunk * MatBr * v_cols + gl_SubgroupID * (MatBc / 4);
coopMatLoad(QMat, ksh, v_tile_offset, vsh_stride, gl_CooperativeMatrixLayoutRowMajor);
}
SfMat = coopMatMulAdd(KMat, QMat, SfMat);
}
// Store SfMat to sfsh and load into Of
const uint osh_stride = row_split * MatBc / 4;
const uint o_offset = gl_SubgroupID * MatBc / 4;
coopMatStore(SfMat, sfsh, o_offset, osh_stride, gl_CooperativeMatrixLayoutRowMajor);
barrier();
const uint hsv_per_tile = row_split * MatBc;
const uint hsv_base = hsv_tile * hsv_per_tile;
const uint d_values_per_tile = hsv_per_tile / 4;
const uint d_start = hsv_tile * d_values_per_tile;
const uint d_end = min(d_start + d_values_per_tile, HSV / 4);
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
const uint row = tile_row(r);
[[unroll]] for (uint32_t d_local = 0; d_local < d_per_thread; ++d_local) {
const uint d = d_local * threads_per_rowgroup + col_tid;
const uint hsv_col = 4 * d;
if (hsv_col >= hsv_base && hsv_col < hsv_base + hsv_per_tile && hsv_col < HSV) {
const uint local_hsv = (hsv_col - hsv_base) / 4;
Of[r][d_local] += ACC_TYPEV4(sfsh[row * osh_stride + local_hsv]);
}
}
}
}
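Stepping back from the vectorized details, the loop above implements the standard streaming ("online") softmax update described in the M/P/eM comments. A minimal scalar sketch of the same bookkeeping, purely illustrative (the shader's actual layouts and types differ):

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// One tile step for a single row: S holds this tile's scores, V its value
// rows; M/L/O are the running max, running sum and running (unnormalized)
// output carried across tiles.
void streaming_softmax_tile(float & M, float & L, std::vector<float> & O,
                            const std::vector<float> & S,
                            const std::vector<std::vector<float>> & V) {
    float rowmax = -1e30f;
    for (float s : S) rowmax = std::max(rowmax, s);

    const float Mold = M;
    M = std::max(rowmax, Mold);          // M  = max(rowmax, Mold)
    const float eM = std::exp(Mold - M); // eM = e^(Mold - M), rescales old state

    L *= eM;
    for (float & o : O) o *= eM;

    for (size_t c = 0; c < S.size(); ++c) {
        const float P = std::exp(S[c] - M); // P = e^(S - M)
        L += P;
        for (size_t d = 0; d < O.size(); ++d) {
            O[d] += P * V[c][d];
        }
    }
    // After the last tile, the normalized row output is O[d] / L.
}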
@ -302,69 +468,8 @@ void main() {
barrier();
}
// prevent race on tmpsh
barrier();
// reduce across threads
float rowmaxf[rows_per_thread], eMf[rows_per_thread], Moldf[rows_per_thread];
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
FLOAT_TYPE M = Mf[r];
tmpsh[tid] = M;
// Compute max across the row
barrier();
[[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) {
M = max(M, tmpsh[tid ^ s]);
barrier();
tmpsh[tid] = M;
barrier();
}
rowmaxf[r] = tmpsh[d_tid + row_tid * threads_per_rowgroup];
barrier();
}
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Moldf[r] = Mf[r];
// M = max(rowmax, Mold)
// eM = e^(Mold - M)
Mf[r] = max(rowmaxf[r], Moldf[r]);
eMf[r] = exp(Moldf[r] - Mf[r]);
Lf[r] = eMf[r]*Lf[r];
}
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
FLOAT_TYPE L = Lf[r];
tmpsh[tid] = L;
// Compute sum across the row
barrier();
[[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) {
L += tmpsh[tid ^ s];
barrier();
tmpsh[tid] = L;
barrier();
}
Lf[r] = tmpsh[d_tid + row_tid * threads_per_rowgroup];
barrier();
}
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
Of[r][d] = ACC_TYPE(eMf[r]) * Of[r][d];
tmpshv4[tid] = Of[r][d];
barrier();
[[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) {
Of[r][d] += tmpshv4[tid ^ s];
barrier();
tmpshv4[tid] = Of[r][d];
barrier();
}
Of[r][d] = tmpshv4[d_tid + row_tid * threads_per_rowgroup];
barrier();
}
Lf[r] = subgroupAdd(Lf[r]);
}
// If there is split_k, then the split_k resolve shader does the final
@ -375,9 +480,12 @@ void main() {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
if (tile_row(r) < N) {
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
const uint d = d0 + col_tid;
if (d >= HSV/4) break;
const uint d_local = d0 / threads_per_rowgroup;
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N);
perElemOpGqaStore(tile_row(r), 4 * d + comp, float(Of[r][d_local][comp]), o_offset, iq2, N);
}
}
}
@ -404,8 +512,9 @@ void main() {
if (sink > Mf[r]) {
ms = exp(Mf[r] - sink);
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
Of[r][d] *= ACC_TYPE(ms);
[[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
const uint d_local = d0 / threads_per_rowgroup;
Of[r][d_local] *= ACC_TYPE(ms);
}
} else {
vs = exp(sink - Mf[r]);
@ -420,11 +529,12 @@ void main() {
Lfrcp[r] = (Lf[r] == 0.0) ? 0.0 : (1.0 / Lf[r]);
}
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
const uint d_local = d0 / threads_per_rowgroup;
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Of[r][d] *= ACC_TYPE(Lfrcp[r]);
Of[r][d_local] *= ACC_TYPE(Lfrcp[r]);
#if defined(ACC_TYPE_MAX)
Of[r][d] = clamp(Of[r][d], -ACC_TYPE_MAX, ACC_TYPE_MAX);
Of[r][d_local] = clamp(Of[r][d_local], -ACC_TYPE_MAX, ACC_TYPE_MAX);
#endif
}
}
@ -434,9 +544,12 @@ void main() {
if (p.gqa_ratio > 1) {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
if (tile_row(r) < N) {
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
const uint d = d0 + col_tid;
if (d >= HSV / 4) break;
const uint d_local = d0 / threads_per_rowgroup;
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N);
perElemOpGqaStore(tile_row(r), 4 * d + comp, float(Of[r][d_local][comp]), o_offset, iq2, N);
}
}
}
@ -444,9 +557,12 @@ void main() {
} else {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
if (i * Br + tile_row(r) < N) {
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
const uint d = d0 + col_tid;
if (d >= HSV / 4) break;
const uint d_local = d0 / threads_per_rowgroup;
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
data_o[o_offset + iq2 * HSV + (i * Br + tile_row(r)) * p.ne1 * HSV + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
data_o[o_offset + iq2 * HSV + (i * Br + tile_row(r)) * p.ne1 * HSV + 4 * d + comp] = D_TYPE(Of[r][d_local][comp]);
}
}
}

View file

@ -55,7 +55,7 @@ ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE ele
return max(elem0, elem1);
}
#if defined(BLOCK_SIZE)
#if BLOCK_SIZE > 1
#define DECODEFUNC , DEQUANTFUNC
#else
#define DECODEFUNC
@ -85,7 +85,7 @@ void main() {
tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0);
#if defined(BLOCK_SIZE)
#if BLOCK_SIZE > 1
tensorLayoutK = setTensorLayoutBlockSizeNV(tensorLayoutK, 1, BLOCK_SIZE);
tensorLayoutV = setTensorLayoutBlockSizeNV(tensorLayoutV, 1, BLOCK_SIZE);
#endif
@ -98,7 +98,7 @@ void main() {
if (Clamp != gl_CooperativeMatrixClampModeConstantNV)
{
q_stride &= ~7;
#if !defined(BLOCK_SIZE)
#if BLOCK_SIZE == 1
k_stride &= ~7;
v_stride &= ~7;
#endif

View file

@ -329,6 +329,12 @@ static void test_loops(testing & t) {
"empty"
);
test_template(t, "for undefined empty",
"{% for i in items %}{{ i }}{% else %}empty{% endfor %}",
json::object(),
"empty"
);
test_template(t, "nested for",
"{% for i in a %}{% for j in b %}{{ i }}{{ j }}{% endfor %}{% endfor %}",
{{"a", json::array({1, 2})}, {"b", json::array({"x", "y"})}},
@ -1018,6 +1024,18 @@ static void test_tests(testing & t) {
{{"x", {{"a", 1}}}},
"yes"
);
test_template(t, "undefined is sequence",
"{{ 'yes' if x is sequence }}",
json::object(),
"yes"
);
test_template(t, "undefined is iterable",
"{{ 'yes' if x is iterable }}",
json::object(),
"yes"
);
}
static void test_string_methods(testing & t) {
@ -1122,6 +1140,54 @@ static void test_string_methods(testing & t) {
{{"s", "banana"}},
"bXnXna"
);
test_template(t, "undefined|capitalize",
"{{ arr|capitalize }}",
json::object(),
""
);
test_template(t, "undefined|title",
"{{ arr|title }}",
json::object(),
""
);
test_template(t, "undefined|truncate",
"{{ arr|truncate(9) }}",
json::object(),
""
);
test_template(t, "undefined|upper",
"{{ arr|upper }}",
json::object(),
""
);
test_template(t, "undefined|lower",
"{{ arr|lower }}",
json::object(),
""
);
test_template(t, "undefined|replace",
"{{ arr|replace('a', 'b') }}",
json::object(),
""
);
test_template(t, "undefined|trim",
"{{ arr|trim }}",
json::object(),
""
);
test_template(t, "undefined|wordcount",
"{{ arr|wordcount }}",
json::object(),
"0"
);
}
static void test_array_methods(testing & t) {
@ -1289,6 +1355,108 @@ static void test_array_methods(testing & t) {
// {{"arr", json::array({"a", "b", "c"})}},
// "a,x,b,c"
// );
test_template(t, "undefined|select",
"{% for item in items|select('odd') %}{{ item.name }} {% endfor %}",
json::object(),
""
);
test_template(t, "undefined|selectattr",
"{% for item in items|selectattr('active') %}{{ item.name }} {% endfor %}",
json::object(),
""
);
test_template(t, "undefined|reject",
"{% for item in items|reject('even') %}{{ item.name }} {% endfor %}",
json::object(),
""
);
test_template(t, "undefined|rejectattr",
"{% for item in items|rejectattr('active') %}{{ item.name }} {% endfor %}",
json::object(),
""
);
test_template(t, "undefined|list",
"{{ arr|list|string }}",
json::object(),
"[]"
);
test_template(t, "undefined|string",
"{{ arr|string }}",
json::object(),
""
);
test_template(t, "undefined|first",
"{{ arr|first }}",
json::object(),
""
);
test_template(t, "undefined|last",
"{{ arr|last }}",
json::object(),
""
);
test_template(t, "undefined|length",
"{{ arr|length }}",
json::object(),
"0"
);
test_template(t, "undefined|join",
"{{ arr|join }}",
json::object(),
""
);
test_template(t, "undefined|sort",
"{{ arr|sort|string }}",
json::object(),
"[]"
);
test_template(t, "undefined|reverse",
"{{ arr|reverse|join }}",
json::object(),
""
);
test_template(t, "undefined|map",
"{% for v in arr|map(attribute='age') %}{{ v }} {% endfor %}",
json::object(),
""
);
test_template(t, "undefined|min",
"{{ arr|min }}",
json::object(),
""
);
test_template(t, "undefined|max",
"{{ arr|max }}",
json::object(),
""
);
test_template(t, "undefined|unique",
"{{ arr|unique|join }}",
json::object(),
""
);
test_template(t, "undefined|sum",
"{{ arr|sum }}",
json::object(),
"0"
);
}
static void test_object_methods(testing & t) {
@ -1393,6 +1561,12 @@ static void test_object_methods(testing & t) {
json::object(),
"True"
);
test_template(t, "undefined|items",
"{{ arr|items|join }}",
json::object(),
""
);
}
static void test_hasher(testing & t) {

View file

@ -48,11 +48,8 @@ enum server_state {
struct server_slot {
int id;
llama_batch batch_spec = {};
// TODO: change to unique_ptrs for consistency:
llama_context * ctx = nullptr;
llama_context * ctx_dft = nullptr;
// multimodal
mtmd_context * mctx = nullptr;
@ -259,7 +256,7 @@ struct server_slot {
}
bool can_speculate() const {
return ctx_dft;
return !!spec;
}
void add_token(const completion_token_output & token) {
@ -295,6 +292,7 @@ struct server_slot {
SLT_DBG(*this, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, task->params.speculative.n_min);
n_draft_max = 0;
}
return n_draft_max;
}
@ -397,6 +395,8 @@ struct server_slot {
draft_ratio, n_draft_accepted, n_draft_total
);
}
common_speculative_print_stats(spec);
}
json to_json(bool only_metrics = false) const {
@ -553,18 +553,13 @@ private:
// note: keep these alive - they determine the lifetime of the model, context, etc.
common_init_result_ptr llama_init;
common_init_result_ptr llama_init_dft;
llama_context * ctx = nullptr;
bool vocab_dft_compatible = true;
llama_model * model_dft = nullptr;
llama_context_params cparams_dft;
llama_batch batch {};
llama_model_ptr model_dft;
bool add_bos_token = true;
int32_t n_ctx; // total context for all clients / slots
@ -597,13 +592,8 @@ private:
// Clear any sampling context
for (server_slot & slot : slots) {
llama_free(slot.ctx_dft);
slot.ctx_dft = nullptr;
common_speculative_free(slot.spec);
slot.spec = nullptr;
llama_batch_free(slot.batch_spec);
}
llama_batch_free(batch);
@ -648,44 +638,39 @@ private:
add_bos_token = llama_vocab_get_add_bos(vocab);
if (params_base.has_speculative()) {
SRV_INF("loading draft model '%s'\n", params_base.speculative.model.path.c_str());
if (params_base.speculative.has_dft()) {
SRV_INF("loading draft model '%s'\n", params_base.speculative.mparams_dft.path.c_str());
const auto & params_spec = params_base.speculative;
auto params_dft = params_base;
params_dft.devices = params_base.speculative.devices;
params_dft.model = params_base.speculative.model;
params_dft.n_ctx = params_base.speculative.n_ctx == 0 ? llama_n_ctx_seq(ctx) : params_base.speculative.n_ctx;
params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers;
params_dft.n_parallel = 1;
params_dft.cache_type_k = params_base.speculative.cache_type_k;
params_dft.cache_type_v = params_base.speculative.cache_type_v;
params_dft.n_ctx = params_spec.n_ctx == 0 ? llama_n_ctx_seq(ctx) : params_spec.n_ctx;
params_dft.n_batch = llama_n_ctx_seq(ctx);
params_dft.devices = params_spec.devices;
params_dft.model = params_spec.mparams_dft;
params_dft.n_gpu_layers = params_spec.n_gpu_layers;
params_dft.cache_type_k = params_spec.cache_type_k;
params_dft.cache_type_v = params_spec.cache_type_v;
params_dft.cpuparams.n_threads = params_base.speculative.cpuparams.n_threads;
params_dft.cpuparams_batch.n_threads = params_base.speculative.cpuparams_batch.n_threads;
params_dft.tensor_buft_overrides = params_base.speculative.tensor_buft_overrides;
if (params_spec.cpuparams.n_threads > 0) {
params_dft.cpuparams.n_threads = params_spec.cpuparams.n_threads;
params_dft.cpuparams_batch.n_threads = params_spec.cpuparams_batch.n_threads;
}
llama_init_dft = common_init_from_params(params_dft);
params_dft.tensor_buft_overrides = params_spec.tensor_buft_overrides;
model_dft = llama_init_dft->model();
auto mparams_dft = common_model_params_to_llama(params_dft);
model_dft.reset(llama_model_load_from_file(params_dft.model.path.c_str(), mparams_dft));
if (model_dft == nullptr) {
SRV_ERR("failed to load draft model, '%s'\n", params_base.speculative.model.path.c_str());
SRV_ERR("failed to load draft model, '%s'\n", params_dft.model.path.c_str());
return false;
}
vocab_dft_compatible = common_speculative_are_compatible(ctx, llama_init_dft->context());
if (!vocab_dft_compatible) {
SRV_INF("the draft model '%s' is not compatible with the target model '%s'. tokens will be translated between the draft and target models.\n", params_base.speculative.model.path.c_str(), params_base.model.path.c_str());
}
const int n_ctx_dft = llama_n_ctx(llama_init_dft->context());
cparams_dft = common_context_params_to_llama(params_dft);
cparams_dft.n_batch = n_ctx_dft;
// the context is not needed - we will create one for each slot
llama_init_dft->free_context();
params_base.speculative.model_dft = model_dft.get();
params_base.speculative.cparams_dft = common_context_params_to_llama(params_dft);
}
std::string & mmproj_path = params_base.mmproj.path;
@ -695,6 +680,7 @@ private:
}
mtmd_context_params mparams = mtmd_context_params_default();
mparams.use_gpu = params_base.mmproj_use_gpu;
mparams.print_timings = false;
mparams.n_threads = params_base.cpuparams.n_threads;
@ -702,6 +688,7 @@ private:
mparams.warmup = params_base.warmup;
mparams.image_min_tokens = params_base.image_min_tokens;
mparams.image_max_tokens = params_base.image_max_tokens;
mctx = mtmd_init_from_file(mmproj_path.c_str(), model, mparams);
if (mctx == nullptr) {
SRV_ERR("failed to load multimodal model, '%s'\n", mmproj_path.c_str());
@ -718,11 +705,6 @@ private:
params_base.n_cache_reuse = 0;
SRV_WRN("%s\n", "cache_reuse is not supported by multimodal, it will be disabled");
}
if (params_base.has_speculative()) {
SRV_ERR("%s\n", "err: speculative decode is not supported by multimodal");
return false;
}
}
if (!llama_memory_can_shift(llama_get_memory(ctx))) {
@ -757,29 +739,24 @@ private:
for (int i = 0; i < params_base.n_parallel; i++) {
server_slot slot;
slot.id = i;
slot.ctx = ctx;
slot.id = i;
slot.ctx = ctx;
slot.n_ctx = n_ctx_slot;
slot.mctx = mctx;
slot.mctx = mctx;
slot.prompt.tokens.has_mtmd = mctx != nullptr;
if (model_dft) {
slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1);
// TODO: rework speculative decoding [TAG_SERVER_SPEC_REWORK]
slot.ctx_dft = llama_init_from_model(model_dft, cparams_dft);
if (slot.ctx_dft == nullptr) {
SRV_ERR("%s", "failed to create draft context\n");
return false;
}
slot.spec = common_speculative_init(slot.ctx, slot.ctx_dft);
if (slot.spec == nullptr) {
SRV_ERR("%s", "failed to create speculator\n");
return false;
}
for (auto & pair : params_base.speculative.replacements) {
common_speculative_add_replacement_tgt_dft(slot.spec, pair.first.c_str(), pair.second.c_str());
// try speculative decoding
{
slot.spec = common_speculative_init(params_base.speculative, slot.ctx);
if (slot.spec) {
if (mctx) {
SRV_ERR("%s\n", "speculative decoding is not supported with multimodal");
return false;
}
SRV_WRN("%s", "speculative decoding context initialized\n");
} else {
SRV_WRN("%s", "speculative decoding context not initialized\n");
}
}
@ -1059,7 +1036,7 @@ private:
return res;
}
std::vector<common_adapter_lora_info> construct_lora_list(const std::map<int, float> & config) {
std::vector<common_adapter_lora_info> construct_lora_list(const std::map<int, float> & config) const {
std::vector<common_adapter_lora_info> output = params_base.lora_adapters; // copy
for (size_t i = 0; i < output.size(); ++i) {
auto it = config.find(i);
@ -1162,7 +1139,7 @@ private:
backend_sampling &= task.params.sampling.backend_sampling;
// TODO: speculative decoding requires multiple samples per batch - not supported yet
backend_sampling &= !(slot.ctx_dft && task.params.speculative.n_max > 0);
backend_sampling &= !(slot.spec && task.params.speculative.n_max > 0);
// TODO: getting post/pre sampling logits is not yet supported with backend sampling
backend_sampling &= !need_logits;
@ -1179,14 +1156,6 @@ private:
slot.smpl.reset();
}
// initialize draft batch
// TODO: rework speculative decoding [TAG_SERVER_SPEC_REWORK]
if (slot.ctx_dft) {
llama_batch_free(slot.batch_spec);
slot.batch_spec = llama_batch_init(task.params.speculative.n_max + 1, 0, 1);
}
slot.task = std::make_unique<const server_task>(std::move(task));
slot.state = slot.task->is_child()
@ -2059,19 +2028,23 @@ private:
// generate draft tokens in speculative decoding mode
// TODO: rework to have a single draft llama_context shared across all slots [TAG_SERVER_SPEC_REWORK]
// perform the speculative drafting for all sequences at the same time in a single batch
int n_draft_max = slot.get_n_draft_max();
const int n_draft_max = slot.get_n_draft_max();
if (n_draft_max > 0) {
if (mctx) {
// we should never reach this, as speculative is automatically disabled if mmproj is loaded
GGML_ABORT("not supported by multimodal");
}
struct common_speculative_params params_spec;
params_spec.n_draft = n_draft_max;
params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.task->params.speculative.n_max;
params_spec.p_min = slot.task->params.speculative.p_min;
const llama_tokens & cached_text_tokens = slot.prompt.tokens.get_text_tokens();
llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, slot.sampled);
const auto & params_spec = slot.task->params.speculative;
llama_tokens draft = common_speculative_draft(slot.spec, params_spec, cached_text_tokens, slot.sampled);
if (draft.size() > (size_t) n_draft_max) {
SLT_WRN(slot, "draft size %d exceeds max %d, truncating\n", (int) draft.size(), n_draft_max);
draft.resize(n_draft_max);
}
// add the sampled token to the batch
slot.i_batch_dft.push_back(batch.n_tokens);
@ -2742,6 +2715,10 @@ private:
// prompt evaluated for next-token prediction
slot.state = SLOT_STATE_GENERATING;
if (slot.can_speculate()) {
common_speculative_begin(slot.spec, slot.prompt.tokens.get_text_tokens());
}
} else if (slot.state != SLOT_STATE_GENERATING) {
continue; // continue loop of slots
}
@ -2813,6 +2790,9 @@ private:
// update how many tokens out of those tested were accepted
slot.n_draft_accepted += ids.size() - 1;
// inform the speculative decoding about the number of accepted tokens
common_speculative_accept(slot.spec, ids.size() - 1);
// rollback to the state before sampling the draft tokens
slot.prompt.tokens.keep_first(slot.prompt.n_tokens() - n_draft);

View file

@ -5,6 +5,7 @@
#include "llama.h"
#include "chat.h"
#include "sampling.h"
#include "speculative.h"
#include "json-schema-to-grammar.h"
using json = nlohmann::ordered_json;
@ -76,6 +77,11 @@ json task_params::to_json(bool only_metrics) const {
{"speculative.n_max", speculative.n_max},
{"speculative.n_min", speculative.n_min},
{"speculative.p_min", speculative.p_min},
{"speculative.type", common_speculative_type_to_str(speculative.type)},
{"speculative.ngram_size_n", speculative.ngram_size_n},
{"speculative.ngram_size_m", speculative.ngram_size_m},
{"speculative.ngram_c_rate", speculative.ngram_check_rate},
{"speculative.ngram_m_hits", speculative.ngram_min_hits},
{"timings_per_token", timings_per_token},
{"post_sampling_probs", post_sampling_probs},
{"backend_sampling", sampling.backend_sampling},
@ -135,6 +141,11 @@ json task_params::to_json(bool only_metrics) const {
{"speculative.n_max", speculative.n_max},
{"speculative.n_min", speculative.n_min},
{"speculative.p_min", speculative.p_min},
{"speculative.type", common_speculative_type_to_str(speculative.type)},
{"speculative.ngram_size_n", speculative.ngram_size_n},
{"speculative.ngram_size_m", speculative.ngram_size_m},
{"speculative.ngram_c_rate", speculative.ngram_check_rate},
{"speculative.ngram_m_hits", speculative.ngram_min_hits},
{"timings_per_token", timings_per_token},
{"post_sampling_probs", post_sampling_probs},
{"backend_sampling", sampling.backend_sampling},
@ -242,6 +253,18 @@ task_params server_task::params_from_json_cmpl(
params.speculative.n_min = std::max(params.speculative.n_min, 0);
params.speculative.n_max = std::max(params.speculative.n_max, 0);
params.speculative.type = common_speculative_type_from_name(json_value(data, "speculative.type", common_speculative_type_to_str(defaults.speculative.type)));
params.speculative.ngram_size_n = json_value(data, "speculative.ngram_size_n", defaults.speculative.ngram_size_n);
params.speculative.ngram_size_m = json_value(data, "speculative.ngram_size_m", defaults.speculative.ngram_size_m);
params.speculative.ngram_check_rate = json_value(data, "speculative.ngram_c_rate", defaults.speculative.ngram_check_rate);
params.speculative.ngram_min_hits = json_value(data, "speculative.ngram_m_hits", defaults.speculative.ngram_min_hits);
params.speculative.ngram_size_n = std::min(std::max(1, (int) params.speculative.ngram_size_n), 1024);
params.speculative.ngram_size_m = std::min(std::max(1, (int) params.speculative.ngram_size_m), 1024);
params.speculative.ngram_check_rate = std::min(std::max(1, (int) params.speculative.ngram_check_rate), 1024);
params.speculative.ngram_min_hits = std::min(std::max(1, (int) params.speculative.ngram_min_hits), 1024);
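Equivalently, std::clamp (C++17, <algorithm>) expresses the same [1, 1024] clamp more directly; a possible alternative form, not what the patch uses:

params.speculative.ngram_size_n = std::clamp((int) params.speculative.ngram_size_n, 1, 1024);
params.speculative.ngram_size_m = std::clamp((int) params.speculative.ngram_size_m, 1, 1024);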
// Use OpenAI API logprobs only if n_probs wasn't provided
if (data.contains("logprobs") && params.sampling.n_probs == defaults.sampling.n_probs){
params.sampling.n_probs = json_value(data, "logprobs", defaults.sampling.n_probs);
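For reference, a completion request supplying the new per-request speculative knobs parsed above might carry a fragment like the following, built here with the same nlohmann json alias the file uses. Field names come from the parser above; the values and the type name "ngram" are illustrative assumptions:

// Illustrative request fragment only; keys match json_value() calls above.
json data = {
    {"speculative.type",         "ngram"},
    {"speculative.ngram_size_n", 4},
    {"speculative.ngram_size_m", 2},
    {"speculative.ngram_c_rate", 16},
    {"speculative.ngram_m_hits", 1},
};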