Merge branch 'upstream' into concedo_experimental

# Conflicts:
#	.github/ISSUE_TEMPLATE/020-enhancement.yml
#	.github/ISSUE_TEMPLATE/030-research.yml
#	.github/ISSUE_TEMPLATE/040-refactor.yml
#	.github/workflows/build.yml
#	Makefile
#	common/CMakeLists.txt
#	examples/CMakeLists.txt
#	examples/infill/infill.cpp
#	examples/lookahead/lookahead.cpp
#	examples/lookup/lookup-stats.cpp
#	examples/lookup/lookup.cpp
#	examples/parallel/parallel.cpp
#	examples/retrieval/retrieval.cpp
#	examples/save-load-state/save-load-state.cpp
#	examples/speculative/speculative.cpp
#	flake.lock
#	ggml/src/ggml-cann/CMakeLists.txt
#	ggml/src/ggml-cann/aclnn_ops.cpp
#	ggml/src/ggml-cann/kernels/CMakeLists.txt
#	ggml/src/ggml-cann/kernels/dup.cpp
#	ggml/src/ggml-cann/kernels/get_row_f16.cpp
#	ggml/src/ggml-cann/kernels/get_row_f32.cpp
#	ggml/src/ggml-cann/kernels/get_row_q4_0.cpp
#	tests/test-arg-parser.cpp
#	tests/test-backend-ops.cpp
This commit is contained in:
Concedo 2024-11-25 16:26:08 +08:00
commit 83350ec314
21 changed files with 801 additions and 377 deletions

View file

@ -236,8 +236,9 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
postprocess_cpu_params(params.cpuparams, nullptr);
postprocess_cpu_params(params.cpuparams_batch, &params.cpuparams);
postprocess_cpu_params(params.draft_cpuparams, &params.cpuparams);
postprocess_cpu_params(params.draft_cpuparams_batch, &params.cpuparams_batch);
postprocess_cpu_params(params.speculative.cpuparams, &params.cpuparams);
postprocess_cpu_params(params.speculative.cpuparams_batch, &params.cpuparams_batch);
if (params.prompt_cache_all && (params.interactive || params.interactive_first)) {
throw std::invalid_argument("error: --prompt-cache-all not supported in interactive mode yet\n");
@ -252,7 +253,7 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
for (auto & antiprompt : params.antiprompt) {
string_process_escapes(antiprompt);
}
for (auto & seq_breaker : params.sparams.dry_sequence_breakers) {
for (auto & seq_breaker : params.sampling.dry_sequence_breakers) {
string_process_escapes(seq_breaker);
}
}
@ -330,7 +331,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
std::string sampler_type_chars;
std::string sampler_type_names;
for (const auto & sampler : params.sparams.samplers) {
for (const auto & sampler : params.sampling.samplers) {
sampler_type_chars += common_sampler_type_to_chr(sampler);
sampler_type_names += common_sampler_type_to_str(sampler) + ";";
}
@ -408,26 +409,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
}
));
add_opt(common_arg(
{"-td", "--threads-draft"}, "N",
"number of threads to use during generation (default: same as --threads)",
[](common_params & params, int value) {
params.draft_cpuparams.n_threads = value;
if (params.draft_cpuparams.n_threads <= 0) {
params.draft_cpuparams.n_threads = std::thread::hardware_concurrency();
}
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
add_opt(common_arg(
{"-tbd", "--threads-batch-draft"}, "N",
"number of threads to use during batch and prompt processing (default: same as --threads-draft)",
[](common_params & params, int value) {
params.draft_cpuparams_batch.n_threads = value;
if (params.draft_cpuparams_batch.n_threads <= 0) {
params.draft_cpuparams_batch.n_threads = std::thread::hardware_concurrency();
}
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
add_opt(common_arg(
{"-C", "--cpu-mask"}, "M",
"CPU affinity mask: arbitrarily long hex. Complements cpu-range (default: \"\")",
@ -516,108 +497,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.cpuparams_batch.poll = value;
}
));
add_opt(common_arg(
{"-Cd", "--cpu-mask-draft"}, "M",
"Draft model CPU affinity mask. Complements cpu-range-draft (default: same as --cpu-mask)",
[](common_params & params, const std::string & mask) {
params.draft_cpuparams.mask_valid = true;
if (!parse_cpu_mask(mask, params.draft_cpuparams.cpumask)) {
throw std::invalid_argument("invalid cpumask");
}
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
add_opt(common_arg(
{"-Crd", "--cpu-range-draft"}, "lo-hi",
"Ranges of CPUs for affinity. Complements --cpu-mask-draft",
[](common_params & params, const std::string & range) {
params.draft_cpuparams.mask_valid = true;
if (!parse_cpu_range(range, params.draft_cpuparams.cpumask)) {
throw std::invalid_argument("invalid range");
}
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
add_opt(common_arg(
{"--cpu-strict-draft"}, "<0|1>",
"Use strict CPU placement for draft model (default: same as --cpu-strict)",
[](common_params & params, int value) {
params.draft_cpuparams.strict_cpu = value;
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
add_opt(common_arg(
{"--prio-draft"}, "N",
string_format("set draft process/thread priority : 0-normal, 1-medium, 2-high, 3-realtime (default: %d)\n", params.draft_cpuparams.priority),
[](common_params & params, int prio) {
if (prio < 0 || prio > 3) {
throw std::invalid_argument("invalid value");
}
params.draft_cpuparams.priority = (enum ggml_sched_priority) prio;
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
add_opt(common_arg(
{"--poll-draft"}, "<0|1>",
"Use polling to wait for draft model work (default: same as --poll])",
[](common_params & params, int value) {
params.draft_cpuparams.poll = value;
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
add_opt(common_arg(
{"-Cbd", "--cpu-mask-batch-draft"}, "M",
"Draft model CPU affinity mask. Complements cpu-range-draft (default: same as --cpu-mask)",
[](common_params & params, const std::string & mask) {
params.draft_cpuparams_batch.mask_valid = true;
if (!parse_cpu_mask(mask, params.draft_cpuparams_batch.cpumask)) {
throw std::invalid_argument("invalid cpumask");
}
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
add_opt(common_arg(
{"-Crbd", "--cpu-range-batch-draft"}, "lo-hi",
"Ranges of CPUs for affinity. Complements --cpu-mask-draft-batch)",
[](common_params & params, const std::string & range) {
params.draft_cpuparams_batch.mask_valid = true;
if (!parse_cpu_range(range, params.draft_cpuparams_batch.cpumask)) {
throw std::invalid_argument("invalid cpumask");
}
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
add_opt(common_arg(
{"--cpu-strict-batch-draft"}, "<0|1>",
"Use strict CPU placement for draft model (default: --cpu-strict-draft)",
[](common_params & params, int value) {
params.draft_cpuparams_batch.strict_cpu = value;
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
add_opt(common_arg(
{"--prio-batch-draft"}, "N",
string_format("set draft process/thread priority : 0-normal, 1-medium, 2-high, 3-realtime (default: %d)\n", params.draft_cpuparams_batch.priority),
[](common_params & params, int prio) {
if (prio < 0 || prio > 3) {
throw std::invalid_argument("invalid value");
}
params.draft_cpuparams_batch.priority = (enum ggml_sched_priority) prio;
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
add_opt(common_arg(
{"--poll-batch-draft"}, "<0|1>",
"Use polling to wait for draft model work (default: --poll-draft)",
[](common_params & params, int value) {
params.draft_cpuparams_batch.poll = value;
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
add_opt(common_arg(
{"--draft"}, "N",
string_format("number of tokens to draft for speculative decoding (default: %d)", params.n_draft),
[](common_params & params, int value) {
params.n_draft = value;
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_LOOKUP}));
add_opt(common_arg(
{"-ps", "--p-split"}, "N",
string_format("speculative decoding split probability (default: %.1f)", (double)params.p_split),
[](common_params & params, const std::string & value) {
params.p_split = std::stof(value);
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
add_opt(common_arg(
{"-lcs", "--lookup-cache-static"}, "FNAME",
"path to static lookup cache to use for lookup decoding (not updated by generation)",
@ -702,7 +581,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
string_format("disable internal libllama performance timings (default: %s)", params.no_perf ? "true" : "false"),
[](common_params & params) {
params.no_perf = true;
params.sparams.no_perf = true;
params.sampling.no_perf = true;
}
).set_env("LLAMA_ARG_NO_PERF"));
add_opt(common_arg(
@ -884,155 +763,155 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
string_format("samplers that will be used for generation in the order, separated by \';\'\n(default: %s)", sampler_type_names.c_str()),
[](common_params & params, const std::string & value) {
const auto sampler_names = string_split<std::string>(value, ';');
params.sparams.samplers = common_sampler_types_from_names(sampler_names, true);
params.sampling.samplers = common_sampler_types_from_names(sampler_names, true);
}
).set_sparam());
add_opt(common_arg(
{"-s", "--seed"}, "SEED",
string_format("RNG seed (default: %d, use random seed for %d)", params.sparams.seed, LLAMA_DEFAULT_SEED),
string_format("RNG seed (default: %d, use random seed for %d)", params.sampling.seed, LLAMA_DEFAULT_SEED),
[](common_params & params, const std::string & value) {
params.sparams.seed = std::stoul(value);
params.sampling.seed = std::stoul(value);
}
).set_sparam());
add_opt(common_arg(
{"--sampling-seq"}, "SEQUENCE",
string_format("simplified sequence for samplers that will be used (default: %s)", sampler_type_chars.c_str()),
[](common_params & params, const std::string & value) {
params.sparams.samplers = common_sampler_types_from_chars(value);
params.sampling.samplers = common_sampler_types_from_chars(value);
}
).set_sparam());
add_opt(common_arg(
{"--ignore-eos"},
"ignore end of stream token and continue generating (implies --logit-bias EOS-inf)",
[](common_params & params) {
params.sparams.ignore_eos = true;
params.sampling.ignore_eos = true;
}
).set_sparam());
add_opt(common_arg(
{"--penalize-nl"},
string_format("penalize newline tokens (default: %s)", params.sparams.penalize_nl ? "true" : "false"),
string_format("penalize newline tokens (default: %s)", params.sampling.penalize_nl ? "true" : "false"),
[](common_params & params) {
params.sparams.penalize_nl = true;
params.sampling.penalize_nl = true;
}
).set_sparam());
add_opt(common_arg(
{"--temp"}, "N",
string_format("temperature (default: %.1f)", (double)params.sparams.temp),
string_format("temperature (default: %.1f)", (double)params.sampling.temp),
[](common_params & params, const std::string & value) {
params.sparams.temp = std::stof(value);
params.sparams.temp = std::max(params.sparams.temp, 0.0f);
params.sampling.temp = std::stof(value);
params.sampling.temp = std::max(params.sampling.temp, 0.0f);
}
).set_sparam());
add_opt(common_arg(
{"--top-k"}, "N",
string_format("top-k sampling (default: %d, 0 = disabled)", params.sparams.top_k),
string_format("top-k sampling (default: %d, 0 = disabled)", params.sampling.top_k),
[](common_params & params, int value) {
params.sparams.top_k = value;
params.sampling.top_k = value;
}
).set_sparam());
add_opt(common_arg(
{"--top-p"}, "N",
string_format("top-p sampling (default: %.1f, 1.0 = disabled)", (double)params.sparams.top_p),
string_format("top-p sampling (default: %.1f, 1.0 = disabled)", (double)params.sampling.top_p),
[](common_params & params, const std::string & value) {
params.sparams.top_p = std::stof(value);
params.sampling.top_p = std::stof(value);
}
).set_sparam());
add_opt(common_arg(
{"--min-p"}, "N",
string_format("min-p sampling (default: %.1f, 0.0 = disabled)", (double)params.sparams.min_p),
string_format("min-p sampling (default: %.1f, 0.0 = disabled)", (double)params.sampling.min_p),
[](common_params & params, const std::string & value) {
params.sparams.min_p = std::stof(value);
params.sampling.min_p = std::stof(value);
}
).set_sparam());
add_opt(common_arg(
{"--xtc-probability"}, "N",
string_format("xtc probability (default: %.1f, 0.0 = disabled)", (double)params.sparams.xtc_probability),
string_format("xtc probability (default: %.1f, 0.0 = disabled)", (double)params.sampling.xtc_probability),
[](common_params & params, const std::string & value) {
params.sparams.xtc_probability = std::stof(value);
params.sampling.xtc_probability = std::stof(value);
}
).set_sparam());
add_opt(common_arg(
{"--xtc-threshold"}, "N",
string_format("xtc threshold (default: %.1f, 1.0 = disabled)", (double)params.sparams.xtc_threshold),
string_format("xtc threshold (default: %.1f, 1.0 = disabled)", (double)params.sampling.xtc_threshold),
[](common_params & params, const std::string & value) {
params.sparams.xtc_threshold = std::stof(value);
params.sampling.xtc_threshold = std::stof(value);
}
).set_sparam());
add_opt(common_arg(
{"--typical"}, "N",
string_format("locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)", (double)params.sparams.typ_p),
string_format("locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)", (double)params.sampling.typ_p),
[](common_params & params, const std::string & value) {
params.sparams.typ_p = std::stof(value);
params.sampling.typ_p = std::stof(value);
}
).set_sparam());
add_opt(common_arg(
{"--repeat-last-n"}, "N",
string_format("last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)", params.sparams.penalty_last_n),
string_format("last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)", params.sampling.penalty_last_n),
[](common_params & params, int value) {
params.sparams.penalty_last_n = value;
params.sparams.n_prev = std::max(params.sparams.n_prev, params.sparams.penalty_last_n);
params.sampling.penalty_last_n = value;
params.sampling.n_prev = std::max(params.sampling.n_prev, params.sampling.penalty_last_n);
}
).set_sparam());
add_opt(common_arg(
{"--repeat-penalty"}, "N",
string_format("penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)", (double)params.sparams.penalty_repeat),
string_format("penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)", (double)params.sampling.penalty_repeat),
[](common_params & params, const std::string & value) {
params.sparams.penalty_repeat = std::stof(value);
params.sampling.penalty_repeat = std::stof(value);
}
).set_sparam());
add_opt(common_arg(
{"--presence-penalty"}, "N",
string_format("repeat alpha presence penalty (default: %.1f, 0.0 = disabled)", (double)params.sparams.penalty_present),
string_format("repeat alpha presence penalty (default: %.1f, 0.0 = disabled)", (double)params.sampling.penalty_present),
[](common_params & params, const std::string & value) {
params.sparams.penalty_present = std::stof(value);
params.sampling.penalty_present = std::stof(value);
}
).set_sparam());
add_opt(common_arg(
{"--frequency-penalty"}, "N",
string_format("repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)", (double)params.sparams.penalty_freq),
string_format("repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)", (double)params.sampling.penalty_freq),
[](common_params & params, const std::string & value) {
params.sparams.penalty_freq = std::stof(value);
params.sampling.penalty_freq = std::stof(value);
}
).set_sparam());
add_opt(common_arg(
{"--dry-multiplier"}, "N",
string_format("set DRY sampling multiplier (default: %.1f, 0.0 = disabled)", (double)params.sparams.dry_multiplier),
string_format("set DRY sampling multiplier (default: %.1f, 0.0 = disabled)", (double)params.sampling.dry_multiplier),
[](common_params & params, const std::string & value) {
params.sparams.dry_multiplier = std::stof(value);
params.sampling.dry_multiplier = std::stof(value);
}
).set_sparam());
add_opt(common_arg(
{"--dry-base"}, "N",
string_format("set DRY sampling base value (default: %.2f)", (double)params.sparams.dry_base),
string_format("set DRY sampling base value (default: %.2f)", (double)params.sampling.dry_base),
[](common_params & params, const std::string & value) {
float potential_base = std::stof(value);
if (potential_base >= 1.0f)
{
params.sparams.dry_base = potential_base;
params.sampling.dry_base = potential_base;
}
}
).set_sparam());
add_opt(common_arg(
{"--dry-allowed-length"}, "N",
string_format("set allowed length for DRY sampling (default: %d)", params.sparams.dry_allowed_length),
string_format("set allowed length for DRY sampling (default: %d)", params.sampling.dry_allowed_length),
[](common_params & params, int value) {
params.sparams.dry_allowed_length = value;
params.sampling.dry_allowed_length = value;
}
).set_sparam());
add_opt(common_arg(
{"--dry-penalty-last-n"}, "N",
string_format("set DRY penalty for the last n tokens (default: %d, 0 = disable, -1 = context size)", params.sparams.dry_penalty_last_n),
string_format("set DRY penalty for the last n tokens (default: %d, 0 = disable, -1 = context size)", params.sampling.dry_penalty_last_n),
[](common_params & params, int value) {
params.sparams.dry_penalty_last_n = value;
params.sampling.dry_penalty_last_n = value;
}
).set_sparam());
add_opt(common_arg(
{"--dry-sequence-breaker"}, "STRING",
string_format("add sequence breaker for DRY sampling, clearing out default breakers (%s) in the process; use \"none\" to not use any sequence breakers\n",
params.sparams.dry_sequence_breakers.empty() ? "none" :
std::accumulate(std::next(params.sparams.dry_sequence_breakers.begin()),
params.sparams.dry_sequence_breakers.end(),
std::string("'") + (params.sparams.dry_sequence_breakers[0] == "\n" ? "\\n" : params.sparams.dry_sequence_breakers[0]) + "'",
params.sampling.dry_sequence_breakers.empty() ? "none" :
std::accumulate(std::next(params.sampling.dry_sequence_breakers.begin()),
params.sampling.dry_sequence_breakers.end(),
std::string("'") + (params.sampling.dry_sequence_breakers[0] == "\n" ? "\\n" : params.sampling.dry_sequence_breakers[0]) + "'",
[](const std::string& a, const std::string& b) {
std::string formatted_b = (b == "\n") ? "\\n" : b;
return a + ", '" + formatted_b + "'";
@ -1041,51 +920,51 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
static bool defaults_cleared = false;
if (!defaults_cleared) {
params.sparams.dry_sequence_breakers.clear();
params.sampling.dry_sequence_breakers.clear();
defaults_cleared = true;
}
if (value == "none") {
params.sparams.dry_sequence_breakers.clear();
params.sampling.dry_sequence_breakers.clear();
} else {
params.sparams.dry_sequence_breakers.emplace_back(value);
params.sampling.dry_sequence_breakers.emplace_back(value);
}
}
).set_sparam());
add_opt(common_arg(
{"--dynatemp-range"}, "N",
string_format("dynamic temperature range (default: %.1f, 0.0 = disabled)", (double)params.sparams.dynatemp_range),
string_format("dynamic temperature range (default: %.1f, 0.0 = disabled)", (double)params.sampling.dynatemp_range),
[](common_params & params, const std::string & value) {
params.sparams.dynatemp_range = std::stof(value);
params.sampling.dynatemp_range = std::stof(value);
}
).set_sparam());
add_opt(common_arg(
{"--dynatemp-exp"}, "N",
string_format("dynamic temperature exponent (default: %.1f)", (double)params.sparams.dynatemp_exponent),
string_format("dynamic temperature exponent (default: %.1f)", (double)params.sampling.dynatemp_exponent),
[](common_params & params, const std::string & value) {
params.sparams.dynatemp_exponent = std::stof(value);
params.sampling.dynatemp_exponent = std::stof(value);
}
).set_sparam());
add_opt(common_arg(
{"--mirostat"}, "N",
string_format("use Mirostat sampling.\nTop K, Nucleus and Locally Typical samplers are ignored if used.\n"
"(default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)", params.sparams.mirostat),
"(default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)", params.sampling.mirostat),
[](common_params & params, int value) {
params.sparams.mirostat = value;
params.sampling.mirostat = value;
}
).set_sparam());
add_opt(common_arg(
{"--mirostat-lr"}, "N",
string_format("Mirostat learning rate, parameter eta (default: %.1f)", (double)params.sparams.mirostat_eta),
string_format("Mirostat learning rate, parameter eta (default: %.1f)", (double)params.sampling.mirostat_eta),
[](common_params & params, const std::string & value) {
params.sparams.mirostat_eta = std::stof(value);
params.sampling.mirostat_eta = std::stof(value);
}
).set_sparam());
add_opt(common_arg(
{"--mirostat-ent"}, "N",
string_format("Mirostat target entropy, parameter tau (default: %.1f)", (double)params.sparams.mirostat_tau),
string_format("Mirostat target entropy, parameter tau (default: %.1f)", (double)params.sampling.mirostat_tau),
[](common_params & params, const std::string & value) {
params.sparams.mirostat_tau = std::stof(value);
params.sampling.mirostat_tau = std::stof(value);
}
).set_sparam());
add_opt(common_arg(
@ -1101,7 +980,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
try {
if (ss >> key && ss >> sign && std::getline(ss, value_str) && (sign == '+' || sign == '-')) {
const float bias = std::stof(value_str) * ((sign == '-') ? -1.0f : 1.0f);
params.sparams.logit_bias.push_back({key, bias});
params.sampling.logit_bias.push_back({key, bias});
} else {
throw std::invalid_argument("invalid input format");
}
@ -1112,9 +991,9 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
).set_sparam());
add_opt(common_arg(
{"--grammar"}, "GRAMMAR",
string_format("BNF-like grammar to constrain generations (see samples in grammars/ dir) (default: '%s')", params.sparams.grammar.c_str()),
string_format("BNF-like grammar to constrain generations (see samples in grammars/ dir) (default: '%s')", params.sampling.grammar.c_str()),
[](common_params & params, const std::string & value) {
params.sparams.grammar = value;
params.sampling.grammar = value;
}
).set_sparam());
add_opt(common_arg(
@ -1128,7 +1007,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
std::copy(
std::istreambuf_iterator<char>(file),
std::istreambuf_iterator<char>(),
std::back_inserter(params.sparams.grammar)
std::back_inserter(params.sampling.grammar)
);
}
).set_sparam());
@ -1136,7 +1015,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
{"-j", "--json-schema"}, "SCHEMA",
"JSON schema to constrain generations (https://json-schema.org/), e.g. `{}` for any JSON object\nFor schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead",
[](common_params & params, const std::string & value) {
params.sparams.grammar = json_schema_to_grammar(json::parse(value));
params.sampling.grammar = json_schema_to_grammar(json::parse(value));
}
).set_sparam());
add_opt(common_arg(
@ -1445,17 +1324,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
}
).set_env("LLAMA_ARG_N_GPU_LAYERS"));
add_opt(common_arg(
{"-ngld", "--gpu-layers-draft", "--n-gpu-layers-draft"}, "N",
"number of layers to store in VRAM for the draft model",
[](common_params & params, int value) {
params.n_gpu_layers_draft = value;
if (!llama_supports_gpu_offload()) {
fprintf(stderr, "warning: not compiled with GPU offload support, --gpu-layers-draft option will be ignored\n");
fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n");
}
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
add_opt(common_arg(
{"-sm", "--split-mode"}, "{none,layer,row}",
"how to split the model across multiple GPUs, one of:\n"
@ -1594,13 +1462,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.model = value;
}
).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_EXPORT_LORA}).set_env("LLAMA_ARG_MODEL"));
add_opt(common_arg(
{"-md", "--model-draft"}, "FNAME",
"draft model for speculative decoding (default: unused)",
[](common_params & params, const std::string & value) {
params.model_draft = value;
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
add_opt(common_arg(
{"-mu", "--model-url"}, "MODEL_URL",
"model download url (default: unused)",
@ -2038,5 +1899,168 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
).set_env("LLAMA_LOG_TIMESTAMPS"));
// speculative parameters
add_opt(common_arg(
{"-td", "--threads-draft"}, "N",
"number of threads to use during generation (default: same as --threads)",
[](common_params & params, int value) {
params.speculative.cpuparams.n_threads = value;
if (params.speculative.cpuparams.n_threads <= 0) {
params.speculative.cpuparams.n_threads = std::thread::hardware_concurrency();
}
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
add_opt(common_arg(
{"-tbd", "--threads-batch-draft"}, "N",
"number of threads to use during batch and prompt processing (default: same as --threads-draft)",
[](common_params & params, int value) {
params.speculative.cpuparams_batch.n_threads = value;
if (params.speculative.cpuparams_batch.n_threads <= 0) {
params.speculative.cpuparams_batch.n_threads = std::thread::hardware_concurrency();
}
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
add_opt(common_arg(
{"-Cd", "--cpu-mask-draft"}, "M",
"Draft model CPU affinity mask. Complements cpu-range-draft (default: same as --cpu-mask)",
[](common_params & params, const std::string & mask) {
params.speculative.cpuparams.mask_valid = true;
if (!parse_cpu_mask(mask, params.speculative.cpuparams.cpumask)) {
throw std::invalid_argument("invalid cpumask");
}
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
add_opt(common_arg(
{"-Crd", "--cpu-range-draft"}, "lo-hi",
"Ranges of CPUs for affinity. Complements --cpu-mask-draft",
[](common_params & params, const std::string & range) {
params.speculative.cpuparams.mask_valid = true;
if (!parse_cpu_range(range, params.speculative.cpuparams.cpumask)) {
throw std::invalid_argument("invalid range");
}
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
add_opt(common_arg(
{"--cpu-strict-draft"}, "<0|1>",
"Use strict CPU placement for draft model (default: same as --cpu-strict)",
[](common_params & params, int value) {
params.speculative.cpuparams.strict_cpu = value;
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
add_opt(common_arg(
{"--prio-draft"}, "N",
string_format("set draft process/thread priority : 0-normal, 1-medium, 2-high, 3-realtime (default: %d)\n", params.speculative.cpuparams.priority),
[](common_params & params, int prio) {
if (prio < 0 || prio > 3) {
throw std::invalid_argument("invalid value");
}
params.speculative.cpuparams.priority = (enum ggml_sched_priority) prio;
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
add_opt(common_arg(
{"--poll-draft"}, "<0|1>",
"Use polling to wait for draft model work (default: same as --poll])",
[](common_params & params, int value) {
params.speculative.cpuparams.poll = value;
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
add_opt(common_arg(
{"-Cbd", "--cpu-mask-batch-draft"}, "M",
"Draft model CPU affinity mask. Complements cpu-range-draft (default: same as --cpu-mask)",
[](common_params & params, const std::string & mask) {
params.speculative.cpuparams_batch.mask_valid = true;
if (!parse_cpu_mask(mask, params.speculative.cpuparams_batch.cpumask)) {
throw std::invalid_argument("invalid cpumask");
}
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
add_opt(common_arg(
{"-Crbd", "--cpu-range-batch-draft"}, "lo-hi",
"Ranges of CPUs for affinity. Complements --cpu-mask-draft-batch)",
[](common_params & params, const std::string & range) {
params.speculative.cpuparams_batch.mask_valid = true;
if (!parse_cpu_range(range, params.speculative.cpuparams_batch.cpumask)) {
throw std::invalid_argument("invalid cpumask");
}
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
add_opt(common_arg(
{"--cpu-strict-batch-draft"}, "<0|1>",
"Use strict CPU placement for draft model (default: --cpu-strict-draft)",
[](common_params & params, int value) {
params.speculative.cpuparams_batch.strict_cpu = value;
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
add_opt(common_arg(
{"--prio-batch-draft"}, "N",
string_format("set draft process/thread priority : 0-normal, 1-medium, 2-high, 3-realtime (default: %d)\n", params.speculative.cpuparams_batch.priority),
[](common_params & params, int prio) {
if (prio < 0 || prio > 3) {
throw std::invalid_argument("invalid value");
}
params.speculative.cpuparams_batch.priority = (enum ggml_sched_priority) prio;
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
add_opt(common_arg(
{"--poll-batch-draft"}, "<0|1>",
"Use polling to wait for draft model work (default: --poll-draft)",
[](common_params & params, int value) {
params.speculative.cpuparams_batch.poll = value;
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
add_opt(common_arg(
{"--draft-max", "--draft", "--draft-n"}, "N",
string_format("number of tokens to draft for speculative decoding (default: %d)", params.speculative.n_max),
[](common_params & params, int value) {
params.speculative.n_max = value;
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_LOOKUP, LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"--draft-min", "--draft-n-min"}, "N",
string_format("minimum number of draft tokens to use for speculative decoding (default: %d)", params.speculative.n_min),
[](common_params & params, int value) {
params.speculative.n_min = value;
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_LOOKUP, LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"--draft-p-split"}, "P",
string_format("speculative decoding split probability (default: %.1f)", (double)params.speculative.p_split),
[](common_params & params, const std::string & value) {
params.speculative.p_split = std::stof(value);
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
add_opt(common_arg(
{"--draft-p-min"}, "P",
string_format("minimum speculative decoding probability (greedy) (default: %.1f)", (double)params.speculative.p_min),
[](common_params & params, const std::string & value) {
params.speculative.p_min = std::stof(value);
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"-cd", "--ctx-size-draft"}, "N",
string_format("size of the prompt context for the draft model (default: %d, 0 = loaded from model)", params.speculative.n_ctx),
[](common_params & params, int value) {
params.speculative.n_ctx = value;
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"-ngld", "--gpu-layers-draft", "--n-gpu-layers-draft"}, "N",
"number of layers to store in VRAM for the draft model",
[](common_params & params, int value) {
params.speculative.n_gpu_layers = value;
if (!llama_supports_gpu_offload()) {
fprintf(stderr, "warning: not compiled with GPU offload support, --gpu-layers-draft option will be ignored\n");
fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n");
}
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"-md", "--model-draft"}, "FNAME",
"draft model for speculative decoding (default: unused)",
[](common_params & params, const std::string & value) {
params.speculative.model = value;
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}));
return ctx_arg;
}

View file

@ -539,11 +539,11 @@ std::string string_from(const struct llama_context * ctx, const struct llama_bat
detokenized.end());
buf << "\n" << std::to_string(i)
<< ":token '" << detokenized << "'"
<< ":pos " << std::to_string(batch.pos[i])
<< ":n_seq_id " << std::to_string(batch.n_seq_id[i])
<< ":seq_id " << std::to_string(batch.seq_id[i][0])
<< ":logits " << std::to_string(batch.logits[i]);
<< ", token '" << detokenized << "'"
<< ", pos " << std::to_string(batch.pos[i])
<< ", n_seq_id " << std::to_string(batch.n_seq_id[i])
<< ", seq_id " << std::to_string(batch.seq_id[i][0])
<< ", logits " << std::to_string(batch.logits[i]);
}
buf << " ]";
@ -927,9 +927,9 @@ struct common_init_result common_init_from_params(common_params & params) {
common_lora_adapters_apply(lctx, iparams.lora_adapters);
}
if (params.sparams.ignore_eos && llama_token_eos(model) == LLAMA_TOKEN_NULL) {
if (params.sampling.ignore_eos && llama_token_eos(model) == LLAMA_TOKEN_NULL) {
LOG_WRN("%s: warning: model does not have an EOS token, ignoring --ignore-eos\n", __func__);
params.sparams.ignore_eos = false;
params.sampling.ignore_eos = false;
}
if (params.warmup) {
@ -1492,6 +1492,66 @@ void common_batch_add(
batch.n_tokens++;
}
//
// Token utils
//
size_t common_lcp(const llama_tokens & a, const llama_tokens & b) {
size_t i;
for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {}
return i;
}
size_t common_lcs(const llama_tokens & a, const llama_tokens & b) {
// check for empty sequences
if (a.empty() || b.empty()) {
return 0;
}
// get the lengths of the input sequences
size_t a_len = a.size();
size_t b_len = b.size();
// initialize the maximum length of the longest common subsequence (LCS)
size_t max_length = 0;
// use two rows instead of a 2D matrix to optimize space
std::vector<size_t> prev_row(b_len + 1, 0);
std::vector<size_t> curr_row(b_len + 1, 0);
// iterate through the elements of a
for (size_t i = 1; i <= a_len; i++) {
// iterate through the elements of b
for (size_t j = 1; j <= b_len; j++) {
// if elements at the current positions match
if (a[i - 1] == b[j - 1]) {
// if it's the first element of either sequences, set LCS length to 1
if (i == 1 || j == 1) {
curr_row[j] = 1;
} else {
// increment LCS length by 1 compared to the previous element
curr_row[j] = prev_row[j - 1] + 1;
}
// update max_length if necessary
if (curr_row[j] > max_length) {
max_length = curr_row[j];
}
} else {
// reset LCS length if elements don't match
curr_row[j] = 0;
}
}
// update the previous row for the next iteration
prev_row = curr_row;
}
// return the maximum length of the LCS
return max_length;
}
//
// Vocab utils
//

View file

@ -33,6 +33,8 @@ struct common_lora_adapter_container : common_lora_adapter_info {
struct llama_lora_adapter * adapter;
};
using llama_tokens = std::vector<llama_token>;
// build info
struct common_control_vector_load_info;
@ -97,8 +99,8 @@ enum dimre_method {
DIMRE_METHOD_MEAN,
};
// sampler parameters
struct common_sampler_params {
// sampling parameters
struct common_params_sampling {
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler
int32_t n_prev = 64; // number of previous tokens to remember
@ -149,19 +151,30 @@ struct common_sampler_params {
std::string print() const;
};
struct common_params_speculative {
int32_t n_ctx = 0; // draft context size
int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding
int32_t n_min = 5; // minimum number of draft tokens to use for speculative decoding
int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
float p_split = 0.1f; // speculative decoding split probability
float p_min = 0.9f; // minimum speculative decoding probability (greedy)
struct cpu_params cpuparams;
struct cpu_params cpuparams_batch;
std::string model = ""; // draft model for speculative decoding // NOLINT
};
struct common_params {
int32_t n_predict = -1; // new tokens to predict
int32_t n_ctx = 4096; // context size
int32_t n_batch = 2048; // logical batch size for prompt processing (must be >=32 to use BLAS)
int32_t n_ubatch = 512; // physical batch size for prompt processing (must be >=32 to use BLAS)
int32_t n_keep = 0; // number of tokens to keep from initial prompt
int32_t n_draft = 5; // number of tokens to draft during speculative decoding
int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited)
int32_t n_parallel = 1; // number of parallel sequences to decode
int32_t n_sequences = 1; // number of sequences to decode
float p_split = 0.1f; // speculative decoding split probability
int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default)
int32_t n_gpu_layers_draft = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs
int32_t grp_attn_n = 1; // group-attention factor
@ -178,8 +191,6 @@ struct common_params {
struct cpu_params cpuparams;
struct cpu_params cpuparams_batch;
struct cpu_params draft_cpuparams;
struct cpu_params draft_cpuparams_batch;
ggml_backend_sched_eval_callback cb_eval = nullptr;
void * cb_eval_user_data = nullptr;
@ -191,10 +202,10 @@ struct common_params {
enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings
enum llama_attention_type attention_type = LLAMA_ATTENTION_TYPE_UNSPECIFIED; // attention type for embeddings
struct common_sampler_params sparams;
struct common_params_sampling sampling;
struct common_params_speculative speculative;
std::string model = ""; // model path // NOLINT
std::string model_draft = ""; // draft model for speculative decoding // NOLINT
std::string model_alias = "unknown"; // model alias // NOLINT
std::string model_url = ""; // model url to download // NOLINT
std::string hf_token = ""; // HF token // NOLINT
@ -457,7 +468,9 @@ struct llama_model * common_load_model_from_hf(const char * repo, const char * f
// clear LoRA adapters from context, then apply new list of adapters
void common_lora_adapters_apply(struct llama_context * ctx, std::vector<common_lora_adapter_container> & lora_adapters);
//
// Batch utils
//
void common_batch_clear(struct llama_batch & batch);
@ -468,6 +481,16 @@ void common_batch_add(
const std::vector<llama_seq_id> & seq_ids,
bool logits);
//
// Token utils
//
// longest common prefix
size_t common_lcp(const llama_tokens & a, const llama_tokens & b);
// longet common subsequence
size_t common_lcs(const llama_tokens & a, const llama_tokens & b);
//
// Vocab utils
//

View file

@ -99,7 +99,7 @@ struct ring_buffer {
};
struct common_sampler {
common_sampler_params params;
common_params_sampling params;
struct llama_sampler * grmr;
struct llama_sampler * chain;
@ -125,7 +125,7 @@ struct common_sampler {
}
};
std::string common_sampler_params::print() const {
std::string common_params_sampling::print() const {
char result[1024];
snprintf(result, sizeof(result),
@ -141,7 +141,7 @@ std::string common_sampler_params::print() const {
return std::string(result);
}
struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_sampler_params & params) {
struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params) {
llama_sampler_chain_params lparams = llama_sampler_chain_default_params();
lparams.no_perf = params.no_perf;
@ -320,6 +320,45 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
return cur_p.data[cur_p.selected].id;
}
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first) {
GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1");
std::vector<llama_token> result;
result.reserve(idxs.size());
size_t i = 0;
for (; i < draft.size(); i++) {
const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
common_sampler_accept(gsmpl, id, true);
result.push_back(id);
if (draft[i] != id) {
break;
}
}
if (i == draft.size()) {
const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
common_sampler_accept(gsmpl, id, true);
result.push_back(id);
}
return result;
}
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first) {
std::vector<int> idxs(draft.size() + 1);
for (size_t i = 0; i < idxs.size(); ++i) {
idxs[i] = i;
}
return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft, grammar_first);
}
uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
return llama_sampler_get_seed(gsmpl->chain);
}

View file

@ -36,7 +36,7 @@ struct common_sampler;
// llama_sampler API overloads
struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_sampler_params & params);
struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params);
void common_sampler_free(struct common_sampler * gsmpl);
@ -60,6 +60,27 @@ void common_perf_print(const struct llama_context * ctx, const struct common_sam
//
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false);
// generalized version of common_sampler_sample
//
// will cross-reference the sampled tokens with a batch of draft tokens and accept those that match
// if the sampler disagrees at some point, we stop and return the accepted tokens up to now
//
// common_sampler_sample_n(gsmpl, ctx, { idx }, {});
//
// is equivalent to
//
// common_sampler_sample(gsmpl, ctx, idx);
// common_sampler_accept(gsmpl, token, true);
//
// requires: idxs.size() == draft.size() + 1
//
// returns at least 1 token, up to idxs.size()
//
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first = false);
// assume idxs == [ 0, 1, 2, ..., draft.size() ]
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first = false);
uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl);
// helpers

269
common/speculative.cpp Normal file
View file

@ -0,0 +1,269 @@
#include "speculative.h"
#include "log.h"
#include "common.h"
#include "sampling.h"
#include <cstring>
#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128
#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
struct common_speculative {
struct llama_context * ctx;
struct common_sampler * smpl;
llama_batch batch;
llama_tokens prompt;
};
struct common_speculative * common_speculative_init(
struct llama_context * ctx_dft) {
auto * result = new common_speculative {
/* .ctx = */ ctx_dft,
/* .smpl = */ nullptr,
/* .batch = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1),
/* .prompt = */ {},
};
// TODO: optimize or pass from outside?
#if 0
{
common_params_sampling params;
params.no_perf = false;
params.top_k = 40;
params.top_p = 0.9;
params.samplers = {
COMMON_SAMPLER_TYPE_TOP_K,
COMMON_SAMPLER_TYPE_TOP_P,
COMMON_SAMPLER_TYPE_INFILL,
};
result->smpl = common_sampler_init(llama_get_model(ctx_dft), params);
}
#else
{
common_params_sampling params;
params.no_perf = false;
params.top_k = 10;
params.samplers = {
COMMON_SAMPLER_TYPE_TOP_K,
};
result->smpl = common_sampler_init(llama_get_model(ctx_dft), params);
}
#endif
return result;
}
void common_speculative_free(struct common_speculative * spec) {
common_sampler_free(spec->smpl);
llama_batch_free(spec->batch);
delete spec;
}
bool common_speculative_are_compatible(
const struct llama_context * ctx_tgt,
const struct llama_context * ctx_dft) {
const struct llama_model * model_tgt = llama_get_model(ctx_tgt);
const struct llama_model * model_dft = llama_get_model(ctx_dft);
const bool vocab_type_tgt = llama_vocab_type(model_tgt);
LOG_DBG("%s: vocab_type tgt: %d\n", __func__, vocab_type_tgt);
const bool vocab_type_dft = llama_vocab_type(model_dft);
LOG_DBG("%s: vocab_type dft: %d\n", __func__, vocab_type_dft);
if (vocab_type_tgt != vocab_type_dft) {
LOG_ERR("%s: draft model vocab type must match target model to use speculation but "
"vocab_type_dft = %d while vocab_type_tgt = %d\n", __func__, vocab_type_dft, vocab_type_tgt);
return false;
}
if (llama_add_bos_token(model_tgt) != llama_add_bos_token(model_dft) ||
llama_add_eos_token(model_tgt) != llama_add_eos_token(model_dft) ||
llama_token_bos(model_tgt) != llama_token_bos(model_dft) ||
llama_token_eos(model_tgt) != llama_token_eos(model_dft)
) {
LOG_ERR("%s: draft model special tokens must match target model to use speculation\n", __func__);
return false;
}
{
const int n_vocab_tgt = llama_n_vocab(model_tgt);
const int n_vocab_dft = llama_n_vocab(model_dft);
const int vocab_diff = std::abs(n_vocab_tgt - n_vocab_dft);
if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
LOG_ERR("%s: draft model vocab must closely match target model to use speculation but "
"target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
__func__, n_vocab_tgt, llama_n_vocab(model_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
return false;
}
for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) {
const char * token_text_tgt = llama_token_get_text(model_tgt, i);
const char * token_text_dft = llama_token_get_text(model_dft, i);
if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
LOG_ERR("%s: draft model vocab must match target model to use speculation but "
"token %d content differs - target '%s', draft '%s'\n", __func__, i,
common_token_to_piece(ctx_tgt, i).c_str(),
common_token_to_piece(ctx_dft, i).c_str());
return false;
}
}
}
return true;
}
llama_tokens common_speculative_gen_draft(
struct common_speculative * spec,
struct common_speculative_params params,
const llama_tokens & prompt_tgt,
llama_token id_last) {
auto & batch = spec->batch;
auto & ctx = spec->ctx;
auto & smpl = spec->smpl;
auto & prompt = spec->prompt;
int reuse_i = 0;
int reuse_n = 0;
const int n_ctx = llama_n_ctx(ctx) - params.n_draft;
const int i_start = std::max<int>(0, (int) prompt_tgt.size() - n_ctx);
// reuse as much as possible from the old draft context
// ideally, the draft context should be as big as the target context and we will always reuse the entire prompt
for (int i = 0; i < (int) prompt.size(); ++i) {
int cur = 0;
while (i_start + cur < (int) prompt_tgt.size() &&
i + cur < (int) prompt.size() &&
prompt_tgt[i_start + cur] == prompt[i + cur]) {
cur++;
}
if ((cur >= params.n_reuse || n_ctx >= (int) prompt_tgt.size()) && cur > reuse_n) {
reuse_i = i;
reuse_n = cur;
}
}
LOG_DBG("%s: reuse_i = %d, reuse_n = %d, prompt = %d\n", __func__, reuse_i, reuse_n, (int) prompt.size());
llama_tokens result;
result.reserve(params.n_draft);
if (reuse_n == 0) {
llama_kv_cache_clear(ctx);
prompt.clear();
} else {
// this happens when a previous draft has been discarded (for example, due to being too small), but the
// target model agreed with it. in this case, we simply pass back the previous results to save compute
if (reuse_i + reuse_n < (int) prompt.size() && prompt[reuse_i + reuse_n] == id_last) {
for (int i = reuse_i + reuse_n + 1; i < (int) prompt.size(); ++i) {
result.push_back(prompt[i]);
if (params.n_draft <= (int) result.size()) {
break;
}
}
return result;
}
if (reuse_i > 0) {
llama_kv_cache_seq_rm (ctx, 0, 0, reuse_i);
llama_kv_cache_seq_add(ctx, 0, reuse_i, -1, -reuse_i);
prompt.erase(prompt.begin(), prompt.begin() + reuse_i);
}
if (reuse_n < (int) prompt.size()) {
llama_kv_cache_seq_rm (ctx, 0, reuse_n, -1);
prompt.erase(prompt.begin() + reuse_n, prompt.end());
}
}
// prepare a batch to evaluate any new tokens in the prompt
common_batch_clear(batch);
for (size_t i = i_start + reuse_n; i < prompt_tgt.size(); ++i) {
//LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_tgt[i]);
common_batch_add(batch, prompt_tgt[i], i - i_start, { 0 }, false);
prompt.push_back(prompt_tgt[i]);
}
// we should rarely end-up here during normal decoding
if (batch.n_tokens > 0) {
//LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str());
llama_decode(ctx, batch);
}
const llama_pos n_past = prompt.size();
LOG_DBG("%s: n_past = %d\n", __func__, n_past);
common_batch_clear(batch);
common_batch_add (batch, id_last, n_past, { 0 }, true);
prompt.push_back(id_last);
//LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx, prompt).c_str());
llama_decode(ctx, batch);
common_sampler_reset(smpl);
// sample n_draft tokens from the draft model
for (int i = 0; i < params.n_draft; ++i) {
common_batch_clear(batch);
common_sampler_sample(smpl, ctx, 0, true);
const auto * cur_p = common_sampler_get_candidates(smpl);
for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, cur_p->data[k].id).c_str());
}
// add drafted token for each sequence
const llama_token id = cur_p->data[0].id;
// only collect very high-confidence draft tokens
if (cur_p->data[0].p < params.p_min) {
break;
}
common_sampler_accept(smpl, id, true);
result.push_back(id);
if (params.n_draft <= (int) result.size()) {
break;
}
common_batch_add(batch, id, n_past + i + 1, { 0 }, true);
// evaluate the drafted tokens on the draft model
llama_decode(ctx, batch);
prompt.push_back(id);
}
return result;
}

28
common/speculative.h Normal file
View file

@ -0,0 +1,28 @@
#pragma once
#include "llama.h"
#include "common.h"
struct common_speculative;
struct common_speculative_params {
int n_draft = 16; // max drafted tokens
int n_reuse = 256;
float p_min = 0.9f; // min probabiliy required to accept a token in the draft
};
struct common_speculative * common_speculative_init(struct llama_context * ctx_dft);
void common_speculative_free(struct common_speculative * spec);
bool common_speculative_are_compatible(
const struct llama_context * ctx_tgt,
const struct llama_context * ctx_dft);
// sample up to n_draft tokens and add them to the batch using the draft model
llama_tokens common_speculative_gen_draft(
struct common_speculative * spec,
struct common_speculative_params params,
const llama_tokens & prompt,
llama_token id_last);

View file

@ -2707,7 +2707,7 @@ class XLMRobertaModel(BertModel):
self.gguf_writer.add_token_scores(scores)
self.gguf_writer.add_token_types(toktypes)
self.gguf_writer.add_add_space_prefix(add_prefix)
self.gguf_writer.add_token_type_count(1)
self.gguf_writer.add_token_type_count(self.hparams.get("type_vocab_size", 1))
self.gguf_writer.add_remove_extra_whitespaces(remove_whitespaces)
if precompiled_charsmap:
self.gguf_writer.add_precompiled_charsmap(precompiled_charsmap)

View file

@ -68,10 +68,10 @@ int main(int argc, char ** argv) {
llama_sampler * smpl = llama_sampler_chain_init(sparams);
llama_sampler_chain_add(smpl, llama_sampler_init_top_k(params.sparams.top_k));
llama_sampler_chain_add(smpl, llama_sampler_init_top_p(params.sparams.top_p, params.sparams.min_keep));
llama_sampler_chain_add(smpl, llama_sampler_init_temp (params.sparams.temp));
llama_sampler_chain_add(smpl, llama_sampler_init_dist (params.sparams.seed));
llama_sampler_chain_add(smpl, llama_sampler_init_top_k(params.sampling.top_k));
llama_sampler_chain_add(smpl, llama_sampler_init_top_p(params.sampling.top_p, params.sampling.min_keep));
llama_sampler_chain_add(smpl, llama_sampler_init_temp (params.sampling.temp));
llama_sampler_chain_add(smpl, llama_sampler_init_dist (params.sampling.seed));
if (ctx == NULL) {
LOG_ERR("%s: error: failed to create the llama_context\n" , __func__);

View file

@ -191,7 +191,7 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_
LOG("\n");
struct common_sampler * smpl = common_sampler_init(ctx_llava->model, params->sparams);
struct common_sampler * smpl = common_sampler_init(ctx_llava->model, params->sampling);
if (!smpl) {
LOG_ERR("%s: failed to initialize sampling subsystem\n", __func__);
exit(1);

View file

@ -237,7 +237,7 @@ static struct common_sampler * llama_init(struct llava_context * ctx_llava, comm
LOG_INF("\n");
struct common_sampler * smpl = common_sampler_init(ctx_llava->model, params->sparams);
struct common_sampler * smpl = common_sampler_init(ctx_llava->model, params->sampling);
return smpl;
}

View file

@ -101,7 +101,7 @@ int main(int argc, char ** argv) {
common_init();
auto & sparams = params.sparams;
auto & sparams = params.sampling;
// save choice to use color for later
// (note for later: this is a slightly awkward choice)

View file

@ -176,7 +176,7 @@ struct server_slot {
// sampling
json json_schema;
struct common_sampler_params sparams;
struct common_params_sampling sparams;
struct common_sampler * smpl = nullptr;
llama_token sampled;
@ -688,7 +688,7 @@ struct server_context {
SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx);
slot.sparams = params.sparams;
slot.sparams = params.sampling;
slot.callback_on_release = [this](int) {
queue_tasks.pop_deferred_task();
@ -744,7 +744,7 @@ struct server_context {
}
// length of the Longest Common Subsequence between the current slot's prompt and the input prompt
int cur_lcs_len = longest_common_subsequence(slot.cache_tokens, task.prompt_tokens);
int cur_lcs_len = common_lcs(slot.cache_tokens, task.prompt_tokens);
// fraction of the common subsequence length compared to the current slot's prompt length
float cur_similarity = static_cast<float>(cur_lcs_len) / static_cast<int>(slot.cache_tokens.size());
@ -789,7 +789,7 @@ struct server_context {
bool launch_slot_with_task(server_slot & slot, const server_task & task) {
slot_params default_params;
// Sampling parameter defaults are loaded from the global server context (but individual requests can still override them)
auto default_sparams = params.sparams;
auto default_sparams = params.sampling;
const auto & data = task.data;
if (data.count("__oaicompat") != 0) {
@ -1961,7 +1961,7 @@ struct server_context {
if (slot.params.cache_prompt) {
// reuse any previously computed tokens that are common with the new prompt
slot.n_past = longest_common_prefix(slot.cache_tokens, prompt_tokens);
slot.n_past = common_lcp(slot.cache_tokens, prompt_tokens);
// reuse chunks from the cached prompt by shifting their KV cache in the new position
if (params.n_cache_reuse > 0) {

View file

@ -24,7 +24,6 @@
#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613"
using json = nlohmann::ordered_json;
using llama_tokens = std::vector<llama_token>;
#define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
#define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
@ -439,62 +438,6 @@ static std::string gen_chatcmplid() {
// other common utils
//
static size_t longest_common_prefix(const llama_tokens & a, const llama_tokens & b) {
size_t i;
for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {}
return i;
}
static size_t longest_common_subsequence(const llama_tokens & a, const llama_tokens & b) {
// check for empty sequences
if (a.empty() || b.empty()) {
return 0;
}
// get the lengths of the input sequences
size_t a_len = a.size();
size_t b_len = b.size();
// initialize the maximum length of the longest common subsequence (LCS)
size_t max_length = 0;
// use two rows instead of a 2D matrix to optimize space
std::vector<size_t> prev_row(b_len + 1, 0);
std::vector<size_t> curr_row(b_len + 1, 0);
// iterate through the elements of a
for (size_t i = 1; i <= a_len; i++) {
// iterate through the elements of b
for (size_t j = 1; j <= b_len; j++) {
// if elements at the current positions match
if (a[i - 1] == b[j - 1]) {
// if it's the first element of either sequences, set LCS length to 1
if (i == 1 || j == 1) {
curr_row[j] = 1;
} else {
// increment LCS length by 1 compared to the previous element
curr_row[j] = prev_row[j - 1] + 1;
}
// update max_length if necessary
if (curr_row[j] > max_length) {
max_length = curr_row[j];
}
} else {
// reset LCS length if elements don't match
curr_row[j] = 0;
}
}
// update the previous row for the next iteration
prev_row = curr_row;
}
// return the maximum length of the LCS
return max_length;
}
static bool ends_with(const std::string & str, const std::string & suffix) {
return str.size() >= suffix.size() && 0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix);
}

View file

@ -13937,7 +13937,7 @@ int ggml_cpu_has_vsx(void) {
}
int ggml_cpu_has_neon(void) {
#if defined(__ARM_ARCH)
#if defined(__ARM_ARCH) && defined(__ARM_NEON)
return ggml_arm_arch_features.has_neon;
#else
return 0;
@ -13945,7 +13945,7 @@ int ggml_cpu_has_neon(void) {
}
int ggml_cpu_has_sve(void) {
#if defined(__ARM_ARCH)
#if defined(__ARM_ARCH) && defined(__ARM_FEATURE_SVE)
return ggml_arm_arch_features.has_sve;
#else
return 0;
@ -13953,7 +13953,7 @@ int ggml_cpu_has_sve(void) {
}
int ggml_cpu_has_matmul_int8(void) {
#if defined(__ARM_ARCH)
#if defined(__ARM_ARCH) && defined(__ARM_FEATURE_MATMUL_INT8)
return ggml_arm_arch_features.has_i8mm;
#else
return 0;
@ -13961,7 +13961,7 @@ int ggml_cpu_has_matmul_int8(void) {
}
int ggml_cpu_get_sve_cnt(void) {
#if defined(__ARM_ARCH)
#if defined(__ARM_ARCH) && defined(__ARM_FEATURE_SVE)
return ggml_arm_arch_features.sve_cnt;
#else
return 0;

View file

@ -1,57 +1,69 @@
#include "common.cuh"
#include "argmax.cuh"
#include "sum.cuh"
#include <algorithm>
#include <cstdint>
static __global__ void argmax_f32(
const float * x, int32_t * dst, const int64_t ncols, const int64_t nrows) {
#include "argmax.cuh"
#include "common.cuh"
#include "sum.cuh"
int argmax_thread = 0;
const int64_t row0 = (int64_t)blockIdx.x*WARP_SIZE;
#pragma unroll
for (int64_t row1 = 0; row1 < WARP_SIZE; ++row1) {
const int64_t row = row0 + row1;
if (row >= nrows) {
break;
}
static __global__ void argmax_f32(const float * __restrict__ x, int32_t * __restrict__ dst, const int64_t ncols) {
const int64_t row = blockIdx.x;
float maxval = -FLT_MAX;
int argmax = -1;
const float * rowx = x + row * ncols;
for (int32_t col = threadIdx.x; col < ncols; col += WARP_SIZE) {
const float val = x[row*ncols + col];
const int bigger = val > maxval;
const int not_bigger = bigger ^ 0x00000001;
maxval = maxval*not_bigger + val*bigger;
argmax = argmax*not_bigger + col*bigger;
for (int32_t col = threadIdx.x; col < ncols; col += blockDim.x) {
const float val = rowx[col];
if (val > maxval) {
maxval = val;
argmax = col;
}
}
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, mask, WARP_SIZE);
const int col = __shfl_xor_sync(0xFFFFFFFF, argmax, mask, WARP_SIZE);
const int bigger = val > maxval;
const int not_bigger = bigger ^ 0x00000001;
maxval = maxval*not_bigger + val*bigger;
argmax = argmax*not_bigger + col*bigger;
for (int offset = 16; offset > 0; offset >>= 1) {
const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE);
const int col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE);
if (val > maxval) {
maxval = val;
argmax = col;
}
}
const int store = row1 == threadIdx.x;
argmax_thread += store*argmax;
const int n_warps = blockDim.x / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE;
const int warp_id = threadIdx.x / WARP_SIZE;
if (n_warps > 1) {
constexpr int max_warps = 1024 / WARP_SIZE;
__shared__ float shared_maxval[max_warps];
__shared__ int shared_argmax[max_warps];
if (lane_id == 0) {
shared_maxval[warp_id] = maxval;
shared_argmax[warp_id] = argmax;
}
const int row = row0 + threadIdx.x;
__syncthreads();
if (row >= nrows) {
return;
if (warp_id == 0) {
if (lane_id < n_warps) {
maxval = shared_maxval[lane_id];
argmax = shared_argmax[lane_id];
}
#pragma unroll
for (int offset = 16; offset > 0; offset >>= 1) {
const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE);
const int col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE);
if (val > maxval) {
maxval = val;
argmax = col;
}
}
}
}
dst[row] = argmax_thread;
if (warp_id == 0 && lane_id == 0) {
dst[row] = argmax;
}
}
void ggml_cuda_argmax(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@ -70,10 +82,10 @@ void ggml_cuda_argmax(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
cudaStream_t stream = ctx.stream();
const int64_t num_blocks = (nrows + WARP_SIZE - 1) / WARP_SIZE;
const dim3 blocks_dim(WARP_SIZE, 1, 1);
const int64_t num_blocks = nrows;
const int64_t num_threads = std::min<int64_t>(1024, (ne00 + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE);
const dim3 blocks_dim(num_threads, 1, 1);
const dim3 blocks_num(num_blocks, 1, 1);
argmax_f32<<<blocks_num, blocks_dim, 0, stream>>>(src0_d, dst_d, ne00, nrows);
argmax_f32<<<blocks_num, blocks_dim, 0, stream>>>(src0_d, dst_d, ne00);
}

View file

@ -180,8 +180,8 @@ static __device__ __forceinline__ int warp_reduce_sum(int x) {
return __reduce_add_sync(0xffffffff, x);
#else
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
x += __shfl_xor_sync(0xffffffff, x, mask, 32);
for (int offset = 16; offset > 0; offset >>= 1) {
x += __shfl_xor_sync(0xffffffff, x, offset, 32);
}
return x;
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_AMPERE
@ -189,17 +189,17 @@ static __device__ __forceinline__ int warp_reduce_sum(int x) {
static __device__ __forceinline__ float warp_reduce_sum(float x) {
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
x += __shfl_xor_sync(0xffffffff, x, mask, 32);
for (int offset = 16; offset > 0; offset >>= 1) {
x += __shfl_xor_sync(0xffffffff, x, offset, 32);
}
return x;
}
static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
for (int offset = 16; offset > 0; offset >>= 1) {
a.x += __shfl_xor_sync(0xffffffff, a.x, offset, 32);
a.y += __shfl_xor_sync(0xffffffff, a.y, offset, 32);
}
return a;
}
@ -209,16 +209,16 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
const half2 a_other = __shfl_xor_sync(0xffffffff, a, mask, 32);
for (int offset = 16; offset > 0; offset >>= 1) {
const half2 a_other = __shfl_xor_sync(0xffffffff, a, offset, 32);
reinterpret_cast<half&>(a.x) += __low2half(a_other);
reinterpret_cast<half&>(a.y) += __high2half(a_other);
}
return a;
#else
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
for (int offset = 16; offset > 0; offset >>= 1) {
a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, offset, 32));
}
return a;
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
@ -231,8 +231,8 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
static __device__ __forceinline__ float warp_reduce_max(float x) {
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
for (int offset = 16; offset > 0; offset >>= 1) {
x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, offset, 32));
}
return x;
}
@ -275,8 +275,8 @@ static __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const hal
static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
x = ggml_cuda_hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
for (int offset = 16; offset > 0; offset >>= 1) {
x = ggml_cuda_hmax2(x, __shfl_xor_sync(0xffffffff, x, offset, 32));
}
return x;
#else

View file

@ -69,8 +69,8 @@ static __global__ void quantize_mmq_q8_1(
// Exchange max. abs. value between vals_per_scale/4 threads.
#pragma unroll
for (int mask = vals_per_scale/8; mask > 0; mask >>= 1) {
amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, WARP_SIZE));
for (int offset = vals_per_scale/8; offset > 0; offset >>= 1) {
amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, offset, WARP_SIZE));
}
float sum;
@ -79,8 +79,8 @@ static __global__ void quantize_mmq_q8_1(
// Exchange calculate sum across vals_per_sum/4 threads.
#pragma unroll
for (int mask = vals_per_sum/8; mask > 0; mask >>= 1) {
sum += __shfl_xor_sync(0xFFFFFFFF, sum, mask, WARP_SIZE);
for (int offset = vals_per_sum/8; offset > 0; offset >>= 1) {
sum += __shfl_xor_sync(0xFFFFFFFF, sum, offset, WARP_SIZE);
}
}

View file

@ -2256,6 +2256,7 @@ struct ggml_tensor * ggml_argmax(
struct ggml_context * ctx,
struct ggml_tensor * a) {
GGML_ASSERT(ggml_is_matrix(a));
GGML_ASSERT(a->ne[0] <= INT32_MAX);
struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, a->ne[1]);
@ -4139,6 +4140,7 @@ struct ggml_tensor * ggml_argsort(
struct ggml_context * ctx,
struct ggml_tensor * a,
enum ggml_sort_order order) {
GGML_ASSERT(a->ne[0] <= INT32_MAX);
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_I32, GGML_MAX_DIMS, a->ne);
ggml_set_op_params_i32(result, 0, (int32_t) order);

View file

@ -545,6 +545,9 @@ class Metadata:
gguf_writer.add_size_label(self.size_label)
if self.license is not None:
if isinstance(self.license, list):
gguf_writer.add_license(",".join(self.license))
else:
gguf_writer.add_license(self.license)
if self.license_name is not None:
gguf_writer.add_license_name(self.license_name)

View file

@ -7223,12 +7223,12 @@ static bool weight_buft_supported(const llama_hparams & hparams, ggml_tensor * w
} break;
case GGML_OP_ADD:
{
ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, w->ne[0], 512);
ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], w->ne[1], w->ne[2], w->ne[3]);
op_tensor = ggml_add(ctx, a, w);
} break;
case GGML_OP_MUL:
{
ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, w->ne[0], 512);
ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], w->ne[1], w->ne[2], w->ne[3]);
op_tensor = ggml_mul(ctx, a, w);
} break;
case GGML_OP_DIV:
@ -18375,13 +18375,13 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
static void llama_kv_cache_update_internal(struct llama_context & lctx) {
bool need_reserve = false;
// apply K-shift if needed
if (lctx.model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE && lctx.kv_self.has_shift) {
if (lctx.kv_self.has_shift) {
if (!llama_kv_cache_can_shift(&lctx)) {
GGML_ABORT("Deepseek2 does not support K-shift");
GGML_ABORT("The current context does not support K-shift");
}
{
// apply K-shift if needed
if (lctx.model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
ggml_backend_sched_reset(lctx.sched.get());
ggml_cgraph * gf = llama_build_graph_k_shift(lctx);
@ -20631,7 +20631,7 @@ void llama_kv_cache_update(struct llama_context * ctx) {
}
bool llama_kv_cache_can_shift(struct llama_context * ctx) {
return ctx->model.arch != LLM_ARCH_DEEPSEEK2; // not supported due to MLA
return !ctx->kv_self.recurrent && ctx->model.arch != LLM_ARCH_DEEPSEEK2; // not supported due to MLA
}
// deprecated