note: also has support for completion tokens count

This commit is contained in:
Concedo 2024-11-01 00:44:14 +08:00
commit a46f8acd03
31 changed files with 138676 additions and 137399 deletions

View file

@ -129,13 +129,13 @@ static void common_params_handle_model_default(common_params & params) {
} }
params.hf_file = params.model; params.hf_file = params.model;
} else if (params.model.empty()) { } else if (params.model.empty()) {
params.model = fs_get_cache_file(string_split(params.hf_file, '/').back()); params.model = fs_get_cache_file(string_split<std::string>(params.hf_file, '/').back());
} }
} else if (!params.model_url.empty()) { } else if (!params.model_url.empty()) {
if (params.model.empty()) { if (params.model.empty()) {
auto f = string_split(params.model_url, '#').front(); auto f = string_split<std::string>(params.model_url, '#').front();
f = string_split(f, '?').front(); f = string_split<std::string>(f, '?').front();
params.model = fs_get_cache_file(string_split(f, '/').back()); params.model = fs_get_cache_file(string_split<std::string>(f, '/').back());
} }
} else if (params.model.empty()) { } else if (params.model.empty()) {
params.model = DEFAULT_MODEL_PATH; params.model = DEFAULT_MODEL_PATH;
@ -252,6 +252,9 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
for (auto & antiprompt : params.antiprompt) { for (auto & antiprompt : params.antiprompt) {
string_process_escapes(antiprompt); string_process_escapes(antiprompt);
} }
// Run string_process_escapes in place over every DRY sequence breaker,
// mirroring what is done for the antiprompts just above.
// NOTE(review): presumably this turns a literal "\n" typed on the command line
// into a real newline - confirm against string_process_escapes.
for (auto & seq_breaker : params.sparams.dry_sequence_breakers) {
string_process_escapes(seq_breaker);
}
} }
if (!params.kv_overrides.empty()) { if (!params.kv_overrides.empty()) {
@ -880,7 +883,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
{"--samplers"}, "SAMPLERS", {"--samplers"}, "SAMPLERS",
string_format("samplers that will be used for generation in the order, separated by \';\'\n(default: %s)", sampler_type_names.c_str()), string_format("samplers that will be used for generation in the order, separated by \';\'\n(default: %s)", sampler_type_names.c_str()),
[](common_params & params, const std::string & value) { [](common_params & params, const std::string & value) {
const auto sampler_names = string_split(value, ';'); const auto sampler_names = string_split<std::string>(value, ';');
params.sparams.samplers = common_sampler_types_from_names(sampler_names, true); params.sparams.samplers = common_sampler_types_from_names(sampler_names, true);
} }
).set_sparam()); ).set_sparam());
@ -941,13 +944,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.sparams.min_p = std::stof(value); params.sparams.min_p = std::stof(value);
} }
).set_sparam()); ).set_sparam());
add_opt(common_arg(
{"--tfs"}, "N",
string_format("tail free sampling, parameter z (default: %.1f, 1.0 = disabled)", (double)params.sparams.tfs_z),
[](common_params & params, const std::string & value) {
params.sparams.tfs_z = std::stof(value);
}
).set_sparam());
add_opt(common_arg( add_opt(common_arg(
{"--xtc-probability"}, "N", {"--xtc-probability"}, "N",
string_format("xtc probability (default: %.1f, 0.0 = disabled)", (double)params.sparams.xtc_probability), string_format("xtc probability (default: %.1f, 0.0 = disabled)", (double)params.sparams.xtc_probability),
@ -998,6 +994,64 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.sparams.penalty_freq = std::stof(value); params.sparams.penalty_freq = std::stof(value);
} }
).set_sparam()); ).set_sparam());
// --dry-multiplier N
// Registers the DRY sampling multiplier as a sampling parameter (set_sparam).
// The help text prints the current default; per that text 0.0 means disabled.
add_opt(common_arg(
{"--dry-multiplier"}, "N",
string_format("set DRY sampling multiplier (default: %.1f, 0.0 = disabled)", (double)params.sparams.dry_multiplier),
[](common_params & params, const std::string & value) {
// std::stof throws on non-numeric input; the value is stored without range checks.
params.sparams.dry_multiplier = std::stof(value);
}
).set_sparam());
// --dry-base N
// Registers the DRY sampling base as a sampling parameter (set_sparam).
// Only values >= 1.0 are accepted.
add_opt(common_arg(
{"--dry-base"}, "N",
string_format("set DRY sampling base value (default: %.2f)", (double)params.sparams.dry_base),
[](common_params & params, const std::string & value) {
float potential_base = std::stof(value);
// A base below 1.0 is silently ignored: the previous value of dry_base is
// kept and no warning or error is emitted here.
if (potential_base >= 1.0f)
{
params.sparams.dry_base = potential_base;
}
}
).set_sparam());
// --dry-allowed-length N
// Registers the DRY allowed length as a sampling parameter (set_sparam).
// The handler receives the argument already converted to int and stores it
// unchanged; no range validation happens here.
add_opt(common_arg(
{"--dry-allowed-length"}, "N",
string_format("set allowed length for DRY sampling (default: %d)", params.sparams.dry_allowed_length),
[](common_params & params, int value) {
params.sparams.dry_allowed_length = value;
}
).set_sparam());
// --dry-penalty-last-n N
// Registers the DRY look-back window as a sampling parameter (set_sparam).
// Per the help text, 0 disables the penalty and -1 means the context size;
// the handler stores the int as given and does not interpret those values itself.
add_opt(common_arg(
{"--dry-penalty-last-n"}, "N",
string_format("set DRY penalty for the last n tokens (default: %d, 0 = disable, -1 = context size)", params.sparams.dry_penalty_last_n),
[](common_params & params, int value) {
params.sparams.dry_penalty_last_n = value;
}
).set_sparam());
// --dry-sequence-breaker STRING
// Registers a repeatable option that replaces the default DRY sequence breakers
// with user-supplied ones; the literal value "none" empties the list.
add_opt(common_arg(
{"--dry-sequence-breaker"}, "STRING",
// The help text lists the current defaults as 'a', 'b', ... (or "none" when the
// list is empty). A breaker equal to a newline is shown as the two characters \n.
// The std::accumulate result is a temporary whose c_str() is consumed by
// string_format within this same full expression.
string_format("add sequence breaker for DRY sampling, clearing out default breakers (%s) in the process; use \"none\" to not use any sequence breakers\n",
params.sparams.dry_sequence_breakers.empty() ? "none" :
std::accumulate(std::next(params.sparams.dry_sequence_breakers.begin()),
params.sparams.dry_sequence_breakers.end(),
// Seed the fold with the first breaker, quoted, so that the separator
// ", " is only inserted between elements.
std::string("'") + (params.sparams.dry_sequence_breakers[0] == "\n" ? "\\n" : params.sparams.dry_sequence_breakers[0]) + "'",
[](const std::string& a, const std::string& b) {
std::string formatted_b = (b == "\n") ? "\\n" : b;
return a + ", '" + formatted_b + "'";
}).c_str()),
[](common_params & params, const std::string & value) {
// The first time this handler runs, the built-in defaults are dropped so that
// only user-specified breakers remain.
// NOTE(review): the flag is a function-local static, so it is shared by every
// invocation in the process; if arguments are parsed more than once, later
// parses will not clear the defaults of a fresh common_params - confirm this
// is acceptable for all callers.
static bool defaults_cleared = false;
if (!defaults_cleared) {
params.sparams.dry_sequence_breakers.clear();
defaults_cleared = true;
}
if (value == "none") {
// "none" removes everything collected so far, including earlier user values.
params.sparams.dry_sequence_breakers.clear();
} else {
params.sparams.dry_sequence_breakers.emplace_back(value);
}
}
).set_sparam());
add_opt(common_arg( add_opt(common_arg(
{"--dynatemp-range"}, "N", {"--dynatemp-range"}, "N",
string_format("dynamic temperature range (default: %.1f, 0.0 = disabled)", (double)params.sparams.dynatemp_range), string_format("dynamic temperature range (default: %.1f, 0.0 = disabled)", (double)params.sparams.dynatemp_range),
@ -1014,7 +1068,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
).set_sparam()); ).set_sparam());
add_opt(common_arg( add_opt(common_arg(
{"--mirostat"}, "N", {"--mirostat"}, "N",
string_format("use Mirostat sampling.\nTop K, Nucleus, Tail Free and Locally Typical samplers are ignored if used.\n" string_format("use Mirostat sampling.\nTop K, Nucleus and Locally Typical samplers are ignored if used.\n"
"(default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)", params.sparams.mirostat), "(default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)", params.sparams.mirostat),
[](common_params & params, int value) { [](common_params & params, int value) {
params.sparams.mirostat = value; params.sparams.mirostat = value;

View file

@ -418,19 +418,6 @@ std::string string_format(const char * fmt, ...) {
return std::string(buf.data(), size); return std::string(buf.data(), size);
} }
std::vector<std::string> string_split(std::string input, char separator) {
std::vector<std::string> parts;
size_t separator_pos = input.find(separator);
while (separator_pos != std::string::npos) {
std::string part = input.substr(0, separator_pos);
parts.emplace_back(part);
input = input.substr(separator_pos + 1);
separator_pos = input.find(separator);
}
parts.emplace_back(input);
return parts;
}
std::string string_strip(const std::string & str) { std::string string_strip(const std::string & str) {
size_t start = 0; size_t start = 0;
size_t end = str.size(); size_t end = str.size();
@ -2021,6 +2008,10 @@ void yaml_dump_non_result_info(FILE * stream, const common_params & params, cons
fprintf(stream, "chunks: %d # default: -1 (unlimited)\n", params.n_chunks); fprintf(stream, "chunks: %d # default: -1 (unlimited)\n", params.n_chunks);
fprintf(stream, "color: %s # default: false\n", params.use_color ? "true" : "false"); fprintf(stream, "color: %s # default: false\n", params.use_color ? "true" : "false");
fprintf(stream, "ctx_size: %d # default: 512\n", params.n_ctx); fprintf(stream, "ctx_size: %d # default: 512\n", params.n_ctx);
fprintf(stream, "dry_allowed_length: %d # default: 2\n", sparams.dry_allowed_length);
fprintf(stream, "dry_base: %.2f # default: 1.75\n", sparams.dry_base);
fprintf(stream, "dry_multiplier: %.1f # default: 0.0\n", sparams.dry_multiplier);
fprintf(stream, "dry_penalty_last_n: %d # default: -1 (0 = disable, -1 = context size)\n", sparams.dry_penalty_last_n);
fprintf(stream, "escape: %s # default: false\n", params.escape ? "true" : "false"); fprintf(stream, "escape: %s # default: false\n", params.escape ? "true" : "false");
fprintf(stream, "file: # never logged, see prompt instead. Can still be specified for input.\n"); fprintf(stream, "file: # never logged, see prompt instead. Can still be specified for input.\n");
fprintf(stream, "frequency_penalty: %f # default: 0.0 \n", sparams.penalty_freq); fprintf(stream, "frequency_penalty: %f # default: 0.0 \n", sparams.penalty_freq);
@ -2101,7 +2092,6 @@ void yaml_dump_non_result_info(FILE * stream, const common_params & params, cons
const std::vector<float> tensor_split_vector(params.tensor_split, params.tensor_split + llama_max_devices()); const std::vector<float> tensor_split_vector(params.tensor_split, params.tensor_split + llama_max_devices());
yaml_dump_vector_float(stream, "tensor_split", tensor_split_vector); yaml_dump_vector_float(stream, "tensor_split", tensor_split_vector);
fprintf(stream, "tfs: %f # default: 1.0\n", sparams.tfs_z);
fprintf(stream, "threads: %d # default: %u\n", params.cpuparams.n_threads, std::thread::hardware_concurrency()); fprintf(stream, "threads: %d # default: %u\n", params.cpuparams.n_threads, std::thread::hardware_concurrency());
fprintf(stream, "top_k: %d # default: 40\n", sparams.top_k); fprintf(stream, "top_k: %d # default: 40\n", sparams.top_k);
fprintf(stream, "top_p: %f # default: 0.95\n", sparams.top_p); fprintf(stream, "top_p: %f # default: 0.95\n", sparams.top_p);

View file

@ -80,14 +80,15 @@ enum llama_example {
enum common_sampler_type { enum common_sampler_type {
COMMON_SAMPLER_TYPE_NONE = 0, COMMON_SAMPLER_TYPE_NONE = 0,
COMMON_SAMPLER_TYPE_TOP_K = 1, COMMON_SAMPLER_TYPE_DRY = 1,
COMMON_SAMPLER_TYPE_TOP_P = 2, COMMON_SAMPLER_TYPE_TOP_K = 2,
COMMON_SAMPLER_TYPE_MIN_P = 3, COMMON_SAMPLER_TYPE_TOP_P = 3,
COMMON_SAMPLER_TYPE_TFS_Z = 4, COMMON_SAMPLER_TYPE_MIN_P = 4,
COMMON_SAMPLER_TYPE_TYPICAL_P = 5, //COMMON_SAMPLER_TYPE_TFS_Z = 5,
COMMON_SAMPLER_TYPE_TEMPERATURE = 6, COMMON_SAMPLER_TYPE_TYPICAL_P = 6,
COMMON_SAMPLER_TYPE_XTC = 7, COMMON_SAMPLER_TYPE_TEMPERATURE = 7,
COMMON_SAMPLER_TYPE_INFILL = 8, COMMON_SAMPLER_TYPE_XTC = 8,
COMMON_SAMPLER_TYPE_INFILL = 9,
}; };
// dimensionality reduction methods, used by cvector-generator // dimensionality reduction methods, used by cvector-generator
@ -100,34 +101,39 @@ enum dimre_method {
struct common_sampler_params { struct common_sampler_params {
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler
int32_t n_prev = 64; // number of previous tokens to remember int32_t n_prev = 64; // number of previous tokens to remember
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens. int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens
int32_t top_k = 40; // <= 0 to use vocab size int32_t top_k = 40; // <= 0 to use vocab size
float top_p = 0.95f; // 1.0 = disabled float top_p = 0.95f; // 1.0 = disabled
float min_p = 0.05f; // 0.0 = disabled float min_p = 0.05f; // 0.0 = disabled
float xtc_probability = 0.00f; // 0.0 = disabled float xtc_probability = 0.00f; // 0.0 = disabled
float xtc_threshold = 0.10f; // > 0.5 disables XTC float xtc_threshold = 0.10f; // > 0.5 disables XTC
float tfs_z = 1.00f; // 1.0 = disabled float typ_p = 1.00f; // typical_p, 1.0 = disabled
float typ_p = 1.00f; // typical_p, 1.0 = disabled float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities float dynatemp_range = 0.00f; // 0.0 = disabled
float dynatemp_range = 0.00f; // 0.0 = disabled float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler
float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size) float penalty_repeat = 1.00f; // 1.0 = disabled
float penalty_repeat = 1.00f; // 1.0 = disabled float penalty_freq = 0.00f; // 0.0 = disabled
float penalty_freq = 0.00f; // 0.0 = disabled float penalty_present = 0.00f; // 0.0 = disabled
float penalty_present = 0.00f; // 0.0 = disabled float dry_multiplier = 0.0f; // 0.0 = disabled; DRY repetition penalty for tokens extending repetition:
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 float dry_base = 1.75f; // 0.0 = disabled; multiplier * base ^ (length of sequence before token - allowed length)
float mirostat_tau = 5.00f; // target entropy int32_t dry_allowed_length = 2; // tokens extending repetitions beyond this receive penalty
float mirostat_eta = 0.10f; // learning rate int32_t dry_penalty_last_n = -1; // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size)
bool penalize_nl = false; // consider newlines as a repeatable token int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
bool ignore_eos = false; float mirostat_tau = 5.00f; // target entropy
bool no_perf = false; // disable performance metrics float mirostat_eta = 0.10f; // learning rate
bool penalize_nl = false; // consider newlines as a repeatable token
bool ignore_eos = false;
bool no_perf = false; // disable performance metrics
std::vector<std::string> dry_sequence_breakers = {"\n", ":", "\"", "*"}; // default sequence breakers for DRY
std::vector<enum common_sampler_type> samplers = { std::vector<enum common_sampler_type> samplers = {
COMMON_SAMPLER_TYPE_DRY,
COMMON_SAMPLER_TYPE_TOP_K, COMMON_SAMPLER_TYPE_TOP_K,
COMMON_SAMPLER_TYPE_TFS_Z,
COMMON_SAMPLER_TYPE_TYPICAL_P, COMMON_SAMPLER_TYPE_TYPICAL_P,
COMMON_SAMPLER_TYPE_TOP_P, COMMON_SAMPLER_TYPE_TOP_P,
COMMON_SAMPLER_TYPE_MIN_P, COMMON_SAMPLER_TYPE_MIN_P,
@ -376,8 +382,6 @@ bool set_process_priority(enum ggml_sched_priority prio);
LLAMA_COMMON_ATTRIBUTE_FORMAT(1, 2) LLAMA_COMMON_ATTRIBUTE_FORMAT(1, 2)
std::string string_format(const char * fmt, ...); std::string string_format(const char * fmt, ...);
std::vector<std::string> string_split(std::string input, char separator);
std::string string_strip(const std::string & str); std::string string_strip(const std::string & str);
std::string string_get_sortable_timestamp(); std::string string_get_sortable_timestamp();
@ -385,6 +389,7 @@ void string_replace_all(std::string & s, const std::string & search, const std::
template<class T> template<class T>
static std::vector<T> string_split(const std::string & str, char delim) { static std::vector<T> string_split(const std::string & str, char delim) {
static_assert(!std::is_same<T, std::string>::value, "Please use the specialized version for std::string");
std::vector<T> values; std::vector<T> values;
std::istringstream str_stream(str); std::istringstream str_stream(str);
std::string token; std::string token;
@ -397,6 +402,22 @@ static std::vector<T> string_split(const std::string & str, char delim) {
return values; return values;
} }
// Specialization of string_split for std::string.
//
// Cuts `input` at every occurrence of `separator` and returns the pieces in
// order. Empty pieces are preserved: adjacent separators, or a separator at
// either end, produce empty strings, and an input without any separator
// (including the empty string) yields a single element equal to the input.
template<>
std::vector<std::string> string_split<std::string>(const std::string & input, char separator)
{
    std::vector<std::string> pieces;
    size_t start = 0;
    for (size_t hit = input.find(separator);
         hit != std::string::npos;
         hit = input.find(separator, start)) {
        // piece between the previous cut and this separator
        pieces.emplace_back(input.substr(start, hit - start));
        start = hit + 1;
    }
    // remainder after the last separator (the whole input if none was found)
    pieces.emplace_back(input.substr(start));
    return pieces;
}
bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides); bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides);
void string_process_escapes(std::string & input); void string_process_escapes(std::string & input);

View file

@ -130,10 +130,12 @@ std::string common_sampler_params::print() const {
snprintf(result, sizeof(result), snprintf(result, sizeof(result),
"\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n" "\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
"\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, temp = %.3f\n" "\tdry_multiplier = %.3f, dry_base = %.3f, dry_allowed_length = %d, dry_penalty_last_n = %d\n"
"\ttop_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, temp = %.3f\n"
"\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f", "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
penalty_last_n, penalty_repeat, penalty_freq, penalty_present, penalty_last_n, penalty_repeat, penalty_freq, penalty_present,
top_k, tfs_z, top_p, min_p, xtc_probability, xtc_threshold, typ_p, temp, dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n,
top_k, top_p, min_p, xtc_probability, xtc_threshold, typ_p, temp,
mirostat, mirostat_eta, mirostat_tau); mirostat, mirostat_eta, mirostat_tau);
return std::string(result); return std::string(result);
@ -174,6 +176,17 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
if (params.mirostat == 0) { if (params.mirostat == 0) {
for (const auto & cnstr : params.samplers) { for (const auto & cnstr : params.samplers) {
switch (cnstr) { switch (cnstr) {
case COMMON_SAMPLER_TYPE_DRY:
{
// llama_sampler_init_dry is given the breakers as an array of C strings plus a
// count, so collect a pointer to each std::string's buffer first.
// The pointers are non-owning views into params.dry_sequence_breakers and are
// only guaranteed valid while that vector is left unmodified.
// NOTE(review): presumably llama_sampler_init_dry copies or tokenizes the
// strings before returning, since c_breakers goes out of scope right after
// the call - confirm.
std::vector<const char*> c_breakers;
c_breakers.reserve(params.dry_sequence_breakers.size());
for (const auto& str : params.dry_sequence_breakers) {
c_breakers.push_back(str.c_str());
}
llama_sampler_chain_add(result->chain, llama_sampler_init_dry (model, params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
}
break;
case COMMON_SAMPLER_TYPE_TOP_K: case COMMON_SAMPLER_TYPE_TOP_K:
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k)); llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
break; break;
@ -186,9 +199,6 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
case COMMON_SAMPLER_TYPE_XTC: case COMMON_SAMPLER_TYPE_XTC:
llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed)); llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
break; break;
case COMMON_SAMPLER_TYPE_TFS_Z:
llama_sampler_chain_add(result->chain, llama_sampler_init_tail_free(params.tfs_z, params.min_keep));
break;
case COMMON_SAMPLER_TYPE_TYPICAL_P: case COMMON_SAMPLER_TYPE_TYPICAL_P:
llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep)); llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep));
break; break;
@ -358,8 +368,8 @@ std::string common_sampler_prev_str(common_sampler * gsmpl, llama_context * ctx_
char common_sampler_type_to_chr(enum common_sampler_type cnstr) { char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
switch (cnstr) { switch (cnstr) {
case COMMON_SAMPLER_TYPE_DRY: return 'd';
case COMMON_SAMPLER_TYPE_TOP_K: return 'k'; case COMMON_SAMPLER_TYPE_TOP_K: return 'k';
case COMMON_SAMPLER_TYPE_TFS_Z: return 'f';
case COMMON_SAMPLER_TYPE_TYPICAL_P: return 'y'; case COMMON_SAMPLER_TYPE_TYPICAL_P: return 'y';
case COMMON_SAMPLER_TYPE_TOP_P: return 'p'; case COMMON_SAMPLER_TYPE_TOP_P: return 'p';
case COMMON_SAMPLER_TYPE_MIN_P: return 'm'; case COMMON_SAMPLER_TYPE_MIN_P: return 'm';
@ -372,8 +382,8 @@ char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
std::string common_sampler_type_to_str(enum common_sampler_type cnstr) { std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
switch (cnstr) { switch (cnstr) {
case COMMON_SAMPLER_TYPE_DRY: return "dry";
case COMMON_SAMPLER_TYPE_TOP_K: return "top_k"; case COMMON_SAMPLER_TYPE_TOP_K: return "top_k";
case COMMON_SAMPLER_TYPE_TFS_Z: return "tfs_z";
case COMMON_SAMPLER_TYPE_TYPICAL_P: return "typ_p"; case COMMON_SAMPLER_TYPE_TYPICAL_P: return "typ_p";
case COMMON_SAMPLER_TYPE_TOP_P: return "top_p"; case COMMON_SAMPLER_TYPE_TOP_P: return "top_p";
case COMMON_SAMPLER_TYPE_MIN_P: return "min_p"; case COMMON_SAMPLER_TYPE_MIN_P: return "min_p";
@ -386,11 +396,11 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
std::vector<common_sampler_type> common_sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names) { std::vector<common_sampler_type> common_sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names) {
std::unordered_map<std::string, common_sampler_type> sampler_canonical_name_map { std::unordered_map<std::string, common_sampler_type> sampler_canonical_name_map {
{ "dry", COMMON_SAMPLER_TYPE_DRY },
{ "top_k", COMMON_SAMPLER_TYPE_TOP_K }, { "top_k", COMMON_SAMPLER_TYPE_TOP_K },
{ "top_p", COMMON_SAMPLER_TYPE_TOP_P }, { "top_p", COMMON_SAMPLER_TYPE_TOP_P },
{ "typ_p", COMMON_SAMPLER_TYPE_TYPICAL_P }, { "typ_p", COMMON_SAMPLER_TYPE_TYPICAL_P },
{ "min_p", COMMON_SAMPLER_TYPE_MIN_P }, { "min_p", COMMON_SAMPLER_TYPE_MIN_P },
{ "tfs_z", COMMON_SAMPLER_TYPE_TFS_Z },
{ "temperature", COMMON_SAMPLER_TYPE_TEMPERATURE }, { "temperature", COMMON_SAMPLER_TYPE_TEMPERATURE },
{ "xtc", COMMON_SAMPLER_TYPE_XTC }, { "xtc", COMMON_SAMPLER_TYPE_XTC },
{ "infill", COMMON_SAMPLER_TYPE_INFILL }, { "infill", COMMON_SAMPLER_TYPE_INFILL },
@ -407,8 +417,6 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
{ "typ-p", COMMON_SAMPLER_TYPE_TYPICAL_P }, { "typ-p", COMMON_SAMPLER_TYPE_TYPICAL_P },
{ "typ", COMMON_SAMPLER_TYPE_TYPICAL_P }, { "typ", COMMON_SAMPLER_TYPE_TYPICAL_P },
{ "min-p", COMMON_SAMPLER_TYPE_MIN_P }, { "min-p", COMMON_SAMPLER_TYPE_MIN_P },
{ "tfs-z", COMMON_SAMPLER_TYPE_TFS_Z },
{ "tfs", COMMON_SAMPLER_TYPE_TFS_Z },
{ "temp", COMMON_SAMPLER_TYPE_TEMPERATURE }, { "temp", COMMON_SAMPLER_TYPE_TEMPERATURE },
}; };
@ -434,8 +442,8 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
std::vector<common_sampler_type> common_sampler_types_from_chars(const std::string & chars) { std::vector<common_sampler_type> common_sampler_types_from_chars(const std::string & chars) {
std::unordered_map<char, common_sampler_type> sampler_name_map = { std::unordered_map<char, common_sampler_type> sampler_name_map = {
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_DRY), COMMON_SAMPLER_TYPE_DRY },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_K), COMMON_SAMPLER_TYPE_TOP_K }, { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_K), COMMON_SAMPLER_TYPE_TOP_K },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TFS_Z), COMMON_SAMPLER_TYPE_TFS_Z },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TYPICAL_P), COMMON_SAMPLER_TYPE_TYPICAL_P }, { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TYPICAL_P), COMMON_SAMPLER_TYPE_TYPICAL_P },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_P), COMMON_SAMPLER_TYPE_TOP_P }, { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_P), COMMON_SAMPLER_TYPE_TOP_P },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_MIN_P), COMMON_SAMPLER_TYPE_MIN_P }, { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_MIN_P), COMMON_SAMPLER_TYPE_MIN_P },

View file

@ -573,6 +573,9 @@ class Model:
if chkhsh == "0876d13b50744004aa9aeae05e7b0647eac9d801b5ba4668afc01e709c15e19f": if chkhsh == "0876d13b50744004aa9aeae05e7b0647eac9d801b5ba4668afc01e709c15e19f":
# ref: https://huggingface.co/BAAI/bge-small-en-v1.5 # ref: https://huggingface.co/BAAI/bge-small-en-v1.5
res = "bert-bge" res = "bert-bge"
if chkhsh == "8e62295832751ca1e8f92f2226f403dea30dc5165e448b5bfa05af5340c64ec7":
# ref: https://huggingface.co/BAAI/bge-large-zh-v1.5
res = "bert-bge-large"
if chkhsh == "b6dc8df998e1cfbdc4eac8243701a65afe638679230920b50d6f17d81c098166": if chkhsh == "b6dc8df998e1cfbdc4eac8243701a65afe638679230920b50d6f17d81c098166":
# ref: https://huggingface.co/mosaicml/mpt-7b # ref: https://huggingface.co/mosaicml/mpt-7b
res = "mpt" res = "mpt"

View file

@ -72,6 +72,7 @@ models = [
{"name": "deepseek-coder", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/deepseek-ai/deepseek-coder-6.7b-base", }, {"name": "deepseek-coder", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/deepseek-ai/deepseek-coder-6.7b-base", },
{"name": "falcon", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/falcon-7b", }, {"name": "falcon", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/falcon-7b", },
{"name": "bert-bge", "tokt": TOKENIZER_TYPE.WPM, "repo": "https://huggingface.co/BAAI/bge-small-en-v1.5", }, {"name": "bert-bge", "tokt": TOKENIZER_TYPE.WPM, "repo": "https://huggingface.co/BAAI/bge-small-en-v1.5", },
{"name": "bert-bge-large", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/BAAI/bge-large-zh-v1.5", },
{"name": "mpt", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/mosaicml/mpt-7b", }, {"name": "mpt", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/mosaicml/mpt-7b", },
{"name": "starcoder", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/bigcode/starcoder2-3b", }, {"name": "starcoder", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/bigcode/starcoder2-3b", },
{"name": "gpt-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/openai-community/gpt2", }, {"name": "gpt-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/openai-community/gpt2", },

View file

@ -40,12 +40,15 @@
repeat_last_n: 0, // 0 = disable penalty, -1 = context size repeat_last_n: 0, // 0 = disable penalty, -1 = context size
repeat_penalty: 1.0, // 1.0 = disabled repeat_penalty: 1.0, // 1.0 = disabled
penalize_nl: false, // true only useful for infinite completion penalize_nl: false, // true only useful for infinite completion
dry_multiplier: 0.0, // 0.0 = disabled, 0.8 works well
dry_base: 1.75, // 0.0 = disabled
dry_allowed_length: 2, // tokens extending repetitions beyond this receive penalty, 2 works well
dry_penalty_last_n: -1, // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size)
top_k: 0, // <= 0 to use vocab size top_k: 0, // <= 0 to use vocab size
top_p: 1.0, // 1.0 = disabled top_p: 1.0, // 1.0 = disabled
min_p: 0.05, // 0 = disabled; recommended for non-english: ~ 0.4 min_p: 0.05, // 0 = disabled; recommended for non-english: ~ 0.4
xtc_probability: 0.0, // 0 = disabled; xtc_probability: 0.0, // 0 = disabled;
xtc_threshold: 0.1, // > 0.5 disables XTC; xtc_threshold: 0.1, // > 0.5 disables XTC;
tfs_z: 1.0, // 1.0 = disabled
typical_p: 1.0, // 1.0 = disabled typical_p: 1.0, // 1.0 = disabled
presence_penalty: 0.0, // 0.0 = disabled presence_penalty: 0.0, // 0.0 = disabled
frequency_penalty: 0.0, // 0.0 = disabled frequency_penalty: 0.0, // 0.0 = disabled
@ -833,13 +836,16 @@ return html`
<fieldset class="params"> <fieldset class="params">
${IntField({ label: "Top-K", title: "Limits the selection of the next token to the K most probable tokens. 1 means no randomness = greedy sampling. If set to 0, it means the entire vocabulary size is considered.", max: 100, min: 0, step: 1, name: "top_k", value: params.value.top_k })} ${IntField({ label: "Top-K", title: "Limits the selection of the next token to the K most probable tokens. 1 means no randomness = greedy sampling. If set to 0, it means the entire vocabulary size is considered.", max: 100, min: 0, step: 1, name: "top_k", value: params.value.top_k })}
${IntField({ label: "Penalize Last N", title: "The last n tokens that are taken into account to penalise repetitions. A value of 0 means that this function is deactivated and -1 means that the entire size of the context is taken into account.", max: 2048, min: 0, step: 16, name: "repeat_last_n", value: params.value.repeat_last_n })} ${IntField({ label: "Penalize Last N", title: "The last n tokens that are taken into account to penalise repetitions. A value of 0 means that this function is deactivated and -1 means that the entire size of the context is taken into account.", max: 2048, min: 0, step: 16, name: "repeat_last_n", value: params.value.repeat_last_n })}
${FloatField({ label: "Top-P", title: "Limits the selection of the next token to a subset of tokens whose combined probability reaches a threshold value P = top-P. If set to 1, it means the entire vocabulary size is considered.", max: 1.0, min: 0.0, name: "top_p", step: 0.01, value: params.value.top_p })}
${FloatField({ label: "Presence Penalty", title: "A penalty that is applied if certain tokens appear repeatedly in the generated text. A higher value leads to fewer repetitions.", max: 1.0, min: 0.0, name: "presence_penalty", step: 0.01, value: params.value.presence_penalty })} ${FloatField({ label: "Presence Penalty", title: "A penalty that is applied if certain tokens appear repeatedly in the generated text. A higher value leads to fewer repetitions.", max: 1.0, min: 0.0, name: "presence_penalty", step: 0.01, value: params.value.presence_penalty })}
${FloatField({ label: "TFS-Z", title: "Activates tail-free sampling, a method used to limit the prediction of tokens that are too frequent. The parameter z controls the strength of this limitation. A value of 1.0 means that this function is deactivated.", max: 1.0, min: 0.0, name: "tfs_z", step: 0.01, value: params.value.tfs_z })}
${FloatField({ label: "Frequency Penalty", title: "A penalty that is applied based on the frequency with which certain tokens occur in the training data set. A higher value results in rare tokens being favoured.", max: 1.0, min: 0.0, name: "frequency_penalty", step: 0.01, value: params.value.frequency_penalty })} ${FloatField({ label: "Frequency Penalty", title: "A penalty that is applied based on the frequency with which certain tokens occur in the training data set. A higher value results in rare tokens being favoured.", max: 1.0, min: 0.0, name: "frequency_penalty", step: 0.01, value: params.value.frequency_penalty })}
${FloatField({ label: "Top-P", title: "Limits the selection of the next token to a subset of tokens whose combined probability reaches a threshold value P = top-P. If set to 1, it means the entire vocabulary size is considered.", max: 1.0, min: 0.0, name: "top_p", step: 0.01, value: params.value.top_p })}
${FloatField({ label: "Typical-P", title: "Activates local typical sampling, a method used to limit the prediction of tokens that are atypical in the current context. The parameter p controls the strength of this limitation. A value of 1.0 means that this function is deactivated.", max: 1.0, min: 0.0, name: "typical_p", step: 0.01, value: params.value.typical_p })} ${FloatField({ label: "Typical-P", title: "Activates local typical sampling, a method used to limit the prediction of tokens that are atypical in the current context. The parameter p controls the strength of this limitation. A value of 1.0 means that this function is deactivated.", max: 1.0, min: 0.0, name: "typical_p", step: 0.01, value: params.value.typical_p })}
${FloatField({ label: "XTC probability", title: "Sets the chance for token removal (checked once on sampler start)", max: 1.0, min: 0.0, name: "xtc_probability", step: 0.01, value: params.value.xtc_probability })} ${FloatField({ label: "XTC probability", title: "Sets the chance for token removal (checked once on sampler start)", max: 1.0, min: 0.0, name: "xtc_probability", step: 0.01, value: params.value.xtc_probability })}
${FloatField({ label: "XTC threshold", title: "Sets a minimum probability threshold for tokens to be removed", max: 0.5, min: 0.0, name: "xtc_threshold", step: 0.01, value: params.value.xtc_threshold })} ${FloatField({ label: "XTC threshold", title: "Sets a minimum probability threshold for tokens to be removed", max: 0.5, min: 0.0, name: "xtc_threshold", step: 0.01, value: params.value.xtc_threshold })}
${FloatField({ label: "DRY Penalty Multiplier", title: "Set the DRY repetition penalty multiplier. Default is 0.0, which disables DRY.", max: 5.0, min: 0.0, name: "dry_multiplier", step: 0.01, value: params.value.dry_multiplier })}
${FloatField({ label: "DRY Base", title: "Set the DRY repetition penalty base value. Default is 1.75", max: 3.0, min: 1.0, name: "dry_base", step: 0.01, value: params.value.dry_base })}
${IntField({ label: "DRY Allowed Length", title: "Tokens that extend repetition beyond this receive exponentially increasing penalty. Default is 2", max: 10, min: 1, step: 1, name: "dry_allowed_length", value: params.value.dry_allowed_length })}
${IntField({ label: "DRY Penalty Last N", title: "How many tokens to scan for repetitions. Default is -1, where 0 is disabled and -1 is context size", max: 2048, min: -1, step: 16, name: "dry_penalty_last_n", value: params.value.dry_penalty_last_n })}
${IntField({ label: "Min Keep", title: "If greater than 0, samplers are forced to return N possible tokens at minimum. Default is 0", max: 10, min: 0, name: "min_keep", value: params.value.min_keep })} ${IntField({ label: "Min Keep", title: "If greater than 0, samplers are forced to return N possible tokens at minimum. Default is 0", max: 10, min: 0, name: "min_keep", value: params.value.min_keep })}
</fieldset> </fieldset>
@ -1139,11 +1145,12 @@ document.addEventListener('DOMContentLoaded', (event) => {
xtc_probability: { snapValue: 0.0, snapRangeMultiplier: 4 }, xtc_probability: { snapValue: 0.0, snapRangeMultiplier: 4 },
xtc_threshold: { snapValue: 0.5, snapRangeMultiplier: 4 }, xtc_threshold: { snapValue: 0.5, snapRangeMultiplier: 4 },
top_p: { snapValue: 1.0, snapRangeMultiplier: 4 }, top_p: { snapValue: 1.0, snapRangeMultiplier: 4 },
tfs_z: { snapValue: 1.0, snapRangeMultiplier: 4 },
typical_p: { snapValue: 1.0, snapRangeMultiplier: 4 }, typical_p: { snapValue: 1.0, snapRangeMultiplier: 4 },
repeat_penalty: { snapValue: 1.0, snapRangeMultiplier: 4 }, repeat_penalty: { snapValue: 1.0, snapRangeMultiplier: 4 },
presence_penalty: { snapValue: 0.0, snapRangeMultiplier: 4 }, presence_penalty: { snapValue: 0.0, snapRangeMultiplier: 4 },
frequency_penalty: { snapValue: 0.0, snapRangeMultiplier: 4 }, frequency_penalty: { snapValue: 0.0, snapRangeMultiplier: 4 },
dry_multiplier: { snapValue: 0.0, snapRangeMultiplier: 4 },
dry_base: { snapValue: 1.75, snapRangeMultiplier: 4 },
}; };
// add an event listener for each slider // add an event listener for each slider
Object.keys(snapSettings).forEach(sliderName => { Object.keys(snapSettings).forEach(sliderName => {

View file

@ -304,12 +304,15 @@
repeat_last_n: 256, // 0 = disable penalty, -1 = context size repeat_last_n: 256, // 0 = disable penalty, -1 = context size
repeat_penalty: 1.18, // 1.0 = disabled repeat_penalty: 1.18, // 1.0 = disabled
penalize_nl: false, penalize_nl: false,
dry_multiplier: 0.0, // 0.0 = disabled, 0.8 works well
dry_base: 1.75, // 0.0 = disabled
dry_allowed_length: 2, // tokens extending repetitions beyond this receive penalty, 2 works well
dry_penalty_last_n: -1, // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size)
top_k: 40, // <= 0 to use vocab size top_k: 40, // <= 0 to use vocab size
top_p: 0.95, // 1.0 = disabled top_p: 0.95, // 1.0 = disabled
min_p: 0.05, // 0 = disabled min_p: 0.05, // 0 = disabled
xtc_probability: 0.0, // 0 = disabled; xtc_probability: 0.0, // 0 = disabled;
xtc_threshold: 0.1, // > 0.5 disables XTC; xtc_threshold: 0.1, // > 0.5 disables XTC;
tfs_z: 1.0, // 1.0 = disabled
typical_p: 1.0, // 1.0 = disabled typical_p: 1.0, // 1.0 = disabled
presence_penalty: 0.0, // 0.0 = disabled presence_penalty: 0.0, // 0.0 = disabled
frequency_penalty: 0.0, // 0.0 = disabled frequency_penalty: 0.0, // 0.0 = disabled
@ -1011,10 +1014,13 @@
<details> <details>
<summary>More options</summary> <summary>More options</summary>
<fieldset class="two"> <fieldset class="two">
${FloatField({ label: "TFS-Z", max: 1.0, min: 0.0, name: "tfs_z", step: 0.01, value: params.value.tfs_z })}
${FloatField({ label: "Typical P", max: 1.0, min: 0.0, name: "typical_p", step: 0.01, value: params.value.typical_p })} ${FloatField({ label: "Typical P", max: 1.0, min: 0.0, name: "typical_p", step: 0.01, value: params.value.typical_p })}
${FloatField({ label: "Presence penalty", max: 1.0, min: 0.0, name: "presence_penalty", step: 0.01, value: params.value.presence_penalty })} ${FloatField({ label: "Presence penalty", max: 1.0, min: 0.0, name: "presence_penalty", step: 0.01, value: params.value.presence_penalty })}
${FloatField({ label: "Frequency penalty", max: 1.0, min: 0.0, name: "frequency_penalty", step: 0.01, value: params.value.frequency_penalty })} ${FloatField({ label: "Frequency penalty", max: 1.0, min: 0.0, name: "frequency_penalty", step: 0.01, value: params.value.frequency_penalty })}
${FloatField({ label: "DRY Penalty Multiplier", max: 5.0, min: 0.0, name: "dry_multiplier", step: 0.01, value: params.value.dry_multiplier })}
${FloatField({ label: "DRY Base", max: 3.0, min: 1.0, name: "dry_base", step: 0.01, value: params.value.dry_base })}
${IntField({ label: "DRY Allowed Length", max: 10, min: 2, step: 1, name: "dry_allowed_length", value: params.value.dry_allowed_length })}
${IntField({ label: "DRY Penalty Last N", max: 2048, min: -1, step: 16, name: "dry_penalty_last_n", value: params.value.dry_penalty_last_n })}
${FloatField({ label: "XTC probability", max: 1.0, min: 0.0, name: "xtc_probability", step: 0.01, value: params.value.xtc_probability })} ${FloatField({ label: "XTC probability", max: 1.0, min: 0.0, name: "xtc_probability", step: 0.01, value: params.value.xtc_probability })}
${FloatField({ label: "XTC threshold", max: 0.5, min: 0.0, name: "xtc_threshold", step: 0.01, value: params.value.xtc_threshold })} ${FloatField({ label: "XTC threshold", max: 0.5, min: 0.0, name: "xtc_threshold", step: 0.01, value: params.value.xtc_threshold })}
</fieldset> </fieldset>

0
examples/server/public/style.css Executable file → Normal file
View file

View file

@ -44,21 +44,6 @@
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
#define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
#define SLT_ERR(slot, fmt, ...) LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
#define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
#define SRV_INF(fmt, ...) LOG_INF("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define SRV_WRN(fmt, ...) LOG_WRN("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define SRV_ERR(fmt, ...) LOG_ERR("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define SRV_DBG(fmt, ...) LOG_DBG("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define QUE_INF(fmt, ...) LOG_INF("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define QUE_WRN(fmt, ...) LOG_WRN("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define QUE_ERR(fmt, ...) LOG_ERR("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define QUE_DBG(fmt, ...) LOG_DBG("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
using json = nlohmann::ordered_json; using json = nlohmann::ordered_json;
enum stop_type { enum stop_type {
@ -69,6 +54,7 @@ enum stop_type {
// state diagram: https://github.com/ggerganov/llama.cpp/pull/9283 // state diagram: https://github.com/ggerganov/llama.cpp/pull/9283
enum slot_state { enum slot_state {
SLOT_STATE_IDLE, SLOT_STATE_IDLE,
SLOT_STATE_STARTED, // TODO: this state is only used for setting up the initial prompt processing; maybe merge it with launch_slot_with_task in the future
SLOT_STATE_PROCESSING_PROMPT, SLOT_STATE_PROCESSING_PROMPT,
SLOT_STATE_DONE_PROMPT, SLOT_STATE_DONE_PROMPT,
SLOT_STATE_GENERATING, SLOT_STATE_GENERATING,
@ -80,7 +66,7 @@ enum server_state {
}; };
enum server_task_type { enum server_task_type {
SERVER_TASK_TYPE_COMPLETION, SERVER_TASK_TYPE_INFERENCE,
SERVER_TASK_TYPE_CANCEL, SERVER_TASK_TYPE_CANCEL,
SERVER_TASK_TYPE_NEXT_RESPONSE, SERVER_TASK_TYPE_NEXT_RESPONSE,
SERVER_TASK_TYPE_METRICS, SERVER_TASK_TYPE_METRICS,
@ -90,21 +76,22 @@ enum server_task_type {
SERVER_TASK_TYPE_SET_LORA, SERVER_TASK_TYPE_SET_LORA,
}; };
enum server_task_cmpl_type { enum server_task_inf_type {
SERVER_TASK_CMPL_TYPE_NORMAL, SERVER_TASK_INF_TYPE_COMPLETION,
SERVER_TASK_CMPL_TYPE_EMBEDDING, SERVER_TASK_INF_TYPE_EMBEDDING,
SERVER_TASK_CMPL_TYPE_RERANK, SERVER_TASK_INF_TYPE_RERANK,
SERVER_TASK_CMPL_TYPE_INFILL, SERVER_TASK_INF_TYPE_INFILL,
}; };
struct server_task { struct server_task {
int id = -1; // to be filled by server_queue int id = -1; // to be filled by server_queue
int id_target = -1; // used by SERVER_TASK_TYPE_CANCEL int id_target = -1; // used by SERVER_TASK_TYPE_CANCEL
llama_tokens prompt_tokens;
server_task_type type; server_task_type type;
json data; json data;
server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL; server_task_inf_type inf_type = SERVER_TASK_INF_TYPE_COMPLETION;
// utility function // utility function
static std::unordered_set<int> get_list_id(const std::vector<server_task> & tasks) { static std::unordered_set<int> get_list_id(const std::vector<server_task> & tasks) {
@ -162,26 +149,20 @@ struct server_slot {
int32_t i_batch = -1; int32_t i_batch = -1;
int32_t n_predict = -1; // TODO: disambiguate from params.n_predict int32_t n_predict = -1; // TODO: disambiguate from params.n_predict
// n_prompt_tokens may not be equal to prompt_tokens.size(), because prompt maybe truncated
int32_t n_prompt_tokens = 0; int32_t n_prompt_tokens = 0;
int32_t n_prompt_tokens_processed = 0; int32_t n_prompt_tokens_processed = 0;
json prompt; // can be either a string, array of strings or array of token ids // input prompt tokens
llama_tokens prompt_tokens;
json input_prefix;
json input_suffix;
json input_extra;
// when a task is submitted, we first tokenize the prompt and store it here
std::vector<llama_token> prompt_tokens;
std::vector<llama_token> extra_tokens;
size_t last_nl_pos = 0; size_t last_nl_pos = 0;
std::string generated_text; std::string generated_text;
std::vector<llama_token> cache_tokens; llama_tokens cache_tokens;
std::vector<completion_token_output> generated_token_probs; std::vector<completion_token_output> generated_token_probs;
server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL; server_task_inf_type inf_type = SERVER_TASK_INF_TYPE_COMPLETION;
bool has_next_token = true; bool has_next_token = true;
bool has_new_line = false; bool has_new_line = false;
@ -230,7 +211,7 @@ struct server_slot {
n_past = 0; n_past = 0;
n_sent_text = 0; n_sent_text = 0;
n_sent_token_probs = 0; n_sent_token_probs = 0;
cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL; inf_type = SERVER_TASK_INF_TYPE_COMPLETION;
generated_token_probs.clear(); generated_token_probs.clear();
} }
@ -735,42 +716,6 @@ struct server_context {
metrics.init(); metrics.init();
} }
std::vector<llama_token> tokenize(const json & json_prompt, bool add_special, bool parse_special) const {
// If `add_bos` is true, we only add BOS, when json_prompt is a string,
// or the first element of the json_prompt array is a string.
std::vector<llama_token> prompt_tokens;
if (json_prompt.is_array()) {
bool first = true;
for (const auto & p : json_prompt) {
if (p.is_string()) {
auto s = p.template get<std::string>();
std::vector<llama_token> p;
if (first) {
p = common_tokenize(ctx, s, add_special, parse_special);
first = false;
} else {
p = common_tokenize(ctx, s, false, parse_special);
}
prompt_tokens.insert(prompt_tokens.end(), p.begin(), p.end());
} else {
if (first) {
first = false;
}
prompt_tokens.push_back(p.template get<llama_token>());
}
}
} else {
auto s = json_prompt.template get<std::string>();
prompt_tokens = common_tokenize(ctx, s, add_special, parse_special);
}
return prompt_tokens;
}
server_slot * get_slot_by_id(int id) { server_slot * get_slot_by_id(int id) {
for (server_slot & slot : slots) { for (server_slot & slot : slots) {
if (slot.id == id) { if (slot.id == id) {
@ -795,22 +740,16 @@ struct server_context {
continue; continue;
} }
// skip the slot if it does not contains prompt // skip the slot if it does not contains cached tokens
if (!slot.prompt.is_string()) { if (slot.prompt_tokens.empty()) {
continue; continue;
} }
// current slot's prompt
std::string slot_prompt = slot.prompt.get<std::string>();
// length of the current slot's prompt
int slot_prompt_len = slot_prompt.size();
// length of the Longest Common Prefix between the current slot's prompt and the input prompt // length of the Longest Common Prefix between the current slot's prompt and the input prompt
int lcp_len = longest_common_prefix(slot_prompt, prompt); int lcp_len = longest_common_prefix(slot.cache_tokens, slot.prompt_tokens);
// fraction of the common substring length compared to the current slot's prompt length // fraction of the common substring length compared to the current slot's prompt length
similarity = static_cast<float>(lcp_len) / slot_prompt_len; similarity = static_cast<float>(lcp_len) / static_cast<int>(slot.prompt_tokens.size());
// select the current slot if the criteria match // select the current slot if the criteria match
if (lcp_len > max_lcp_len && similarity > slot_prompt_similarity) { if (lcp_len > max_lcp_len && similarity > slot_prompt_similarity) {
@ -862,35 +801,57 @@ struct server_context {
slot.oaicompat_model = ""; slot.oaicompat_model = "";
} }
slot.params.stream = json_value(data, "stream", false); slot.params.stream = json_value(data, "stream", false);
slot.params.cache_prompt = json_value(data, "cache_prompt", false); slot.params.cache_prompt = json_value(data, "cache_prompt", false);
slot.params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", default_params.n_predict)); slot.params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", default_params.n_predict));
slot.params.n_indent = json_value(data, "n_indent", default_params.n_indent); slot.params.n_indent = json_value(data, "n_indent", default_params.n_indent);
slot.sparams.top_k = json_value(data, "top_k", default_sparams.top_k); slot.sparams.top_k = json_value(data, "top_k", default_sparams.top_k);
slot.sparams.top_p = json_value(data, "top_p", default_sparams.top_p); slot.sparams.top_p = json_value(data, "top_p", default_sparams.top_p);
slot.sparams.min_p = json_value(data, "min_p", default_sparams.min_p); slot.sparams.min_p = json_value(data, "min_p", default_sparams.min_p);
slot.sparams.xtc_probability = json_value(data, "xtc_probability", default_sparams.xtc_probability); slot.sparams.xtc_probability = json_value(data, "xtc_probability", default_sparams.xtc_probability);
slot.sparams.xtc_threshold = json_value(data, "xtc_threshold", default_sparams.xtc_threshold); slot.sparams.xtc_threshold = json_value(data, "xtc_threshold", default_sparams.xtc_threshold);
slot.sparams.tfs_z = json_value(data, "tfs_z", default_sparams.tfs_z); slot.sparams.typ_p = json_value(data, "typical_p", default_sparams.typ_p);
slot.sparams.typ_p = json_value(data, "typical_p", default_sparams.typ_p); slot.sparams.temp = json_value(data, "temperature", default_sparams.temp);
slot.sparams.temp = json_value(data, "temperature", default_sparams.temp); slot.sparams.dynatemp_range = json_value(data, "dynatemp_range", default_sparams.dynatemp_range);
slot.sparams.dynatemp_range = json_value(data, "dynatemp_range", default_sparams.dynatemp_range); slot.sparams.dynatemp_exponent = json_value(data, "dynatemp_exponent", default_sparams.dynatemp_exponent);
slot.sparams.dynatemp_exponent = json_value(data, "dynatemp_exponent", default_sparams.dynatemp_exponent); slot.sparams.penalty_last_n = json_value(data, "repeat_last_n", default_sparams.penalty_last_n);
slot.sparams.penalty_last_n = json_value(data, "repeat_last_n", default_sparams.penalty_last_n); slot.sparams.penalty_repeat = json_value(data, "repeat_penalty", default_sparams.penalty_repeat);
slot.sparams.penalty_repeat = json_value(data, "repeat_penalty", default_sparams.penalty_repeat); slot.sparams.penalty_freq = json_value(data, "frequency_penalty", default_sparams.penalty_freq);
slot.sparams.penalty_freq = json_value(data, "frequency_penalty", default_sparams.penalty_freq); slot.sparams.penalty_present = json_value(data, "presence_penalty", default_sparams.penalty_present);
slot.sparams.penalty_present = json_value(data, "presence_penalty", default_sparams.penalty_present); slot.sparams.dry_multiplier = json_value(data, "dry_multiplier", default_sparams.dry_multiplier);
slot.sparams.mirostat = json_value(data, "mirostat", default_sparams.mirostat); slot.sparams.dry_base = json_value(data, "dry_base", default_sparams.dry_base);
slot.sparams.mirostat_tau = json_value(data, "mirostat_tau", default_sparams.mirostat_tau); slot.sparams.dry_allowed_length = json_value(data, "dry_allowed_length", default_sparams.dry_allowed_length);
slot.sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta); slot.sparams.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", default_sparams.dry_penalty_last_n);
slot.sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl); slot.sparams.mirostat = json_value(data, "mirostat", default_sparams.mirostat);
slot.params.n_keep = json_value(data, "n_keep", default_params.n_keep); slot.sparams.mirostat_tau = json_value(data, "mirostat_tau", default_sparams.mirostat_tau);
slot.params.n_discard = json_value(data, "n_discard", default_params.n_discard); slot.sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta);
slot.sparams.seed = json_value(data, "seed", default_sparams.seed); slot.sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl);
slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs); slot.params.n_keep = json_value(data, "n_keep", default_params.n_keep);
slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep); slot.params.n_discard = json_value(data, "n_discard", default_params.n_discard);
//slot.params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", default_params.t_max_prompt_ms); // TODO: implement slot.sparams.seed = json_value(data, "seed", default_sparams.seed);
slot.params.t_max_predict_ms = json_value(data, "t_max_predict_ms", default_params.t_max_predict_ms); slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);
//slot.params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", default_params.t_max_prompt_ms); // TODO: implement
slot.params.t_max_predict_ms = json_value(data, "t_max_predict_ms", default_params.t_max_predict_ms);
if (slot.sparams.dry_base < 1.0f)
{
slot.sparams.dry_base = default_sparams.dry_base;
}
// sequence breakers for DRY
{
// Currently, this is not compatible with TextGen WebUI, Koboldcpp and SillyTavern format
// Ref: https://github.com/oobabooga/text-generation-webui/blob/d1af7a41ade7bd3c3a463bfa640725edb818ebaf/extensions/openai/typing.py#L39
if (data.contains("dry_sequence_breakers")) {
slot.sparams.dry_sequence_breakers = json_value(data, "dry_sequence_breakers", std::vector<std::string>());
if (slot.sparams.dry_sequence_breakers.empty()) {
send_error(task, "Error: dry_sequence_breakers must be a non-empty array of strings", ERROR_TYPE_INVALID_REQUEST);
return false;
}
}
}
// process "json_schema" and "grammar" // process "json_schema" and "grammar"
if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) { if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
@ -915,57 +876,6 @@ struct server_context {
SLT_WRN(slot, "n_predict = %d exceeds server configuration, setting to %d", slot.n_predict, slot.n_predict); SLT_WRN(slot, "n_predict = %d exceeds server configuration, setting to %d", slot.n_predict, slot.n_predict);
} }
// infill
slot.input_prefix = json_value(data, "input_prefix", json());
slot.input_suffix = json_value(data, "input_suffix", json());
slot.input_extra = json_value(data, "input_extra", json());
SLT_DBG(slot, "extra_context chunks: %d\n", (int) slot.input_extra.size());
for (const auto & chunk : slot.input_extra) {
// { "text": string, "filename": string }
if (!chunk.contains("text") || !chunk["text"].is_string()) {
send_error(task, "extra_context chunk must contain a \"text\" field with a string value", ERROR_TYPE_INVALID_REQUEST);
return false;
}
// filename is optional
if (chunk.contains("filename") && !chunk["filename"].is_string()) {
send_error(task, "extra_context chunk's \"filename\" field must be a string", ERROR_TYPE_INVALID_REQUEST);
return false;
}
SLT_DBG(slot, "extra_context chunk in file '%s':\n%s\n", chunk.value("filename", "").c_str(), chunk.value("text", "").c_str());
}
// get prompt
{
const auto & prompt = data.find("prompt");
if (prompt == data.end()) {
send_error(task, "\"prompt\" must be provided", ERROR_TYPE_INVALID_REQUEST);
return false;
}
if ((prompt->is_string()) ||
(prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_string()) ||
(prompt->is_array() && !prompt->empty() && prompt->at(0).is_number_integer())) {
slot.prompt = *prompt;
} else if (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_array()) {
slot.prompt = prompt->at(0);
} else if (prompt->is_array() && prompt->size() > 1) {
// array of strings
for (const auto & el : *prompt) {
if (!el.is_string()) {
send_error(task, "\"prompt\" must be a string, an array of strings or an array of integers", ERROR_TYPE_INVALID_REQUEST);
return false;
}
}
slot.prompt = *prompt;
} else {
send_error(task, "\"prompt\" must be a string, an array of strings or an array of integers", ERROR_TYPE_INVALID_REQUEST);
return false;
}
}
{ {
slot.sparams.logit_bias.clear(); slot.sparams.logit_bias.clear();
@ -1045,8 +955,7 @@ struct server_context {
} }
} }
slot.state = SLOT_STATE_PROCESSING_PROMPT; slot.state = SLOT_STATE_STARTED;
slot.prompt_tokens.clear();
SLT_INF(slot, "%s", "processing task\n"); SLT_INF(slot, "%s", "processing task\n");
@ -1240,12 +1149,16 @@ struct server_context {
{"min_p", slot.sparams.min_p}, {"min_p", slot.sparams.min_p},
{"xtc_probability", slot.sparams.xtc_probability}, {"xtc_probability", slot.sparams.xtc_probability},
{"xtc_threshold", slot.sparams.xtc_threshold}, {"xtc_threshold", slot.sparams.xtc_threshold},
{"tfs_z", slot.sparams.tfs_z},
{"typical_p", slot.sparams.typ_p}, {"typical_p", slot.sparams.typ_p},
{"repeat_last_n", slot.sparams.penalty_last_n}, {"repeat_last_n", slot.sparams.penalty_last_n},
{"repeat_penalty", slot.sparams.penalty_repeat}, {"repeat_penalty", slot.sparams.penalty_repeat},
{"presence_penalty", slot.sparams.penalty_present}, {"presence_penalty", slot.sparams.penalty_present},
{"frequency_penalty", slot.sparams.penalty_freq}, {"frequency_penalty", slot.sparams.penalty_freq},
{"dry_multiplier", slot.sparams.dry_multiplier},
{"dry_base", slot.sparams.dry_base},
{"dry_allowed_length", slot.sparams.dry_allowed_length},
{"dry_penalty_last_n", slot.sparams.dry_penalty_last_n},
{"dry_sequence_breakers", slot.sparams.dry_sequence_breakers},
{"mirostat", slot.sparams.mirostat}, {"mirostat", slot.sparams.mirostat},
{"mirostat_tau", slot.sparams.mirostat_tau}, {"mirostat_tau", slot.sparams.mirostat_tau},
{"mirostat_eta", slot.sparams.mirostat_eta}, {"mirostat_eta", slot.sparams.mirostat_eta},
@ -1298,7 +1211,7 @@ struct server_context {
}; };
if (slot.sparams.n_probs > 0) { if (slot.sparams.n_probs > 0) {
const std::vector<llama_token> to_send_toks = common_tokenize(ctx, tkn.text_to_send, false); const llama_tokens to_send_toks = common_tokenize(ctx, tkn.text_to_send, false);
const size_t probs_pos = std::min(slot.n_sent_token_probs, slot.generated_token_probs.size()); const size_t probs_pos = std::min(slot.n_sent_token_probs, slot.generated_token_probs.size());
const size_t probs_stop_pos = std::min(slot.n_sent_token_probs + to_send_toks.size(), slot.generated_token_probs.size()); const size_t probs_stop_pos = std::min(slot.n_sent_token_probs + to_send_toks.size(), slot.generated_token_probs.size());
@ -1334,7 +1247,7 @@ struct server_context {
{"tokens_predicted", slot.n_decoded}, {"tokens_predicted", slot.n_decoded},
{"tokens_evaluated", slot.n_prompt_tokens}, {"tokens_evaluated", slot.n_prompt_tokens},
{"generation_settings", get_formated_generation(slot)}, {"generation_settings", get_formated_generation(slot)},
{"prompt", slot.prompt}, {"prompt", common_detokenize(ctx, slot.prompt_tokens)},
{"has_new_line", slot.has_new_line}, {"has_new_line", slot.has_new_line},
{"truncated", slot.truncated}, {"truncated", slot.truncated},
{"stopped_eos", slot.stopped_eos}, {"stopped_eos", slot.stopped_eos},
@ -1349,7 +1262,7 @@ struct server_context {
if (slot.sparams.n_probs > 0) { if (slot.sparams.n_probs > 0) {
std::vector<completion_token_output> probs; std::vector<completion_token_output> probs;
if (!slot.params.stream && slot.stopped_word) { if (!slot.params.stream && slot.stopped_word) {
const std::vector<llama_token> stop_word_toks = common_tokenize(ctx, slot.stopping_word, false); const llama_tokens stop_word_toks = common_tokenize(ctx, slot.stopping_word, false);
size_t safe_offset = std::min(slot.generated_token_probs.size(), stop_word_toks.size()); size_t safe_offset = std::min(slot.generated_token_probs.size(), stop_word_toks.size());
probs = std::vector<completion_token_output>( probs = std::vector<completion_token_output>(
@ -1458,19 +1371,17 @@ struct server_context {
// Functions to create new task(s) and receive result(s) // Functions to create new task(s) and receive result(s)
// //
std::vector<server_task> create_tasks_cmpl(json data, server_task_cmpl_type cmpl_type) { // break the input "prompt" into multiple tasks if needed, then format and tokenize the input prompt(s)
std::vector<server_task> create_tasks_inference(json data, server_task_inf_type inf_type) {
std::vector<server_task> tasks; std::vector<server_task> tasks;
auto create_task = [&](json & task_data, bool replace_prompt, json prompt) { auto create_task = [&](json & task_data, llama_tokens & prompt_tokens) {
SRV_DBG("create task, n_tokens = %d\n", (int) prompt_tokens.size());
server_task task; server_task task;
task.id = queue_tasks.get_new_id(); task.id = queue_tasks.get_new_id();
task.cmpl_type = cmpl_type; task.inf_type = inf_type;
task.type = SERVER_TASK_TYPE_COMPLETION; task.type = SERVER_TASK_TYPE_INFERENCE;
if (replace_prompt) { task.data = task_data;
task.data = task_data; task.prompt_tokens = std::move(prompt_tokens);
task.data["prompt"] = std::move(prompt);
} else {
task.data = std::move(task_data);
}
tasks.push_back(std::move(task)); tasks.push_back(std::move(task));
}; };
@ -1479,41 +1390,49 @@ struct server_context {
throw std::runtime_error(error_msg); throw std::runtime_error(error_msg);
} }
json prompt = data.at("prompt"); // because llama_tokenize api is thread-safe, we can tokenize the prompt from HTTP thread
bool add_special = inf_type != SERVER_TASK_INF_TYPE_RERANK && inf_type != SERVER_TASK_INF_TYPE_INFILL;
// if the prompt is a singleton (i.e. a string or a list of tokens), we only need to create single task std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx, data.at("prompt"), add_special, true);
if (prompt.is_string() || json_is_array_of_numbers(prompt)) { switch (inf_type) {
data["index"] = 0; case SERVER_TASK_INF_TYPE_RERANK:
create_task(data, false, nullptr); {
} else if (prompt.is_array()) { // prompts[0] is the question
// otherwise, it's a multiple-prompt task, we break it into smaller tasks // the rest are the answers/documents
std::vector<json> prompts = prompt; GGML_ASSERT(tokenized_prompts.size() > 1);
if (cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) { SRV_DBG("creating rerank tasks, n_prompts = %d\n", (int) tokenized_prompts.size() - 1);
// prompts[0] is the question for (size_t i = 1; i < tokenized_prompts.size(); i++) {
// the rest are the answers/documents data["index"] = i - 1;
SRV_DBG("creating rerank tasks, n_prompts = %d\n", (int) prompts.size() - 1); auto tokens = format_rerank(model, tokenized_prompts[0], tokenized_prompts[i]);
for (size_t i = 1; i < prompts.size(); i++) { create_task(data, tokens);
json qd; }
qd.push_back(prompts[0]); } break;
qd.push_back(prompts[i]); case SERVER_TASK_INF_TYPE_INFILL:
data["index"] = i - 1; {
create_task(data, true, qd); SRV_DBG("creating infill tasks, n_prompts = %d\n", (int) tokenized_prompts.size());
} for (size_t i = 0; i < tokenized_prompts.size(); i++) {
} else {
SRV_DBG("creating multi-prompt tasks, n_prompts = %d\n", (int) prompts.size());
for (size_t i = 0; i < prompts.size(); i++) {
const auto & e = prompts[i];
if (e.is_string() || json_is_array_of_numbers(e)) {
data["index"] = i; data["index"] = i;
create_task(data, true, e); auto tokens = format_infill(
} else { ctx,
throw std::runtime_error(error_msg); data.at("input_prefix"),
data.at("input_suffix"),
data.at("input_extra"),
params.n_batch,
params.n_predict,
slots[0].n_ctx, // TODO: there should be a better way
params.spm_infill,
tokenized_prompts[i]
);
create_task(data, tokens);
}
} break;
default:
{
SRV_DBG("creating multi-prompt tasks, n_prompts = %d\n", (int) tokenized_prompts.size());
for (size_t i = 0; i < tokenized_prompts.size(); i++) {
data["index"] = i;
create_task(data, tokenized_prompts[i]);
} }
} }
}
} else {
// invalid case
throw std::runtime_error(error_msg);
} }
return tasks; return tasks;
@ -1535,7 +1454,7 @@ struct server_context {
queue_tasks.post(cancel_tasks, true); queue_tasks.post(cancel_tasks, true);
} }
// receive the results from task(s) created by create_tasks_cmpl // receive the results from task(s) created by create_tasks_inference
void receive_cmpl_results( void receive_cmpl_results(
const std::unordered_set<int> & id_tasks, const std::unordered_set<int> & id_tasks,
const std::function<void(std::vector<server_task_result>&)> & result_handler, const std::function<void(std::vector<server_task_result>&)> & result_handler,
@ -1559,7 +1478,7 @@ struct server_context {
result_handler(results); result_handler(results);
} }
// receive the results from task(s) created by create_tasks_cmpl, in stream mode // receive the results from task(s) created by create_tasks_inference, in stream mode
void receive_cmpl_results_stream( void receive_cmpl_results_stream(
const std::unordered_set<int> & id_tasks, const const std::unordered_set<int> & id_tasks, const
std::function<bool(server_task_result&)> & result_handler, const std::function<bool(server_task_result&)> & result_handler, const
@ -1592,7 +1511,7 @@ struct server_context {
void process_single_task(const server_task & task) { void process_single_task(const server_task & task) {
switch (task.type) { switch (task.type) {
case SERVER_TASK_TYPE_COMPLETION: case SERVER_TASK_TYPE_INFERENCE:
{ {
const int id_slot = json_value(task.data, "id_slot", -1); const int id_slot = json_value(task.data, "id_slot", -1);
@ -1624,9 +1543,10 @@ struct server_context {
slot->reset(); slot->reset();
slot->id_task = task.id; slot->id_task = task.id;
slot->cmpl_type = task.cmpl_type; slot->inf_type = task.inf_type;
slot->index = json_value(task.data, "index", 0); slot->index = json_value(task.data, "index", 0);
slot->prompt_tokens = std::move(task.prompt_tokens);
if (!launch_slot_with_task(*slot, task)) { if (!launch_slot_with_task(*slot, task)) {
SRV_ERR("failed to launch slot with task, id_task = %d\n", task.id); SRV_ERR("failed to launch slot with task, id_task = %d\n", task.id);
@ -1659,7 +1579,7 @@ struct server_context {
slot_data["id"] = slot.id; slot_data["id"] = slot.id;
slot_data["id_task"] = slot.id_task; slot_data["id_task"] = slot.id_task;
slot_data["state"] = slot.state; slot_data["state"] = slot.state;
slot_data["prompt"] = slot.prompt; slot_data["prompt"] = common_detokenize(ctx, slot.prompt_tokens);
slot_data["next_token"] = { slot_data["next_token"] = {
{"has_next_token", slot.has_next_token}, {"has_next_token", slot.has_next_token},
{"has_new_line", slot.has_new_line}, {"has_new_line", slot.has_new_line},
@ -1786,9 +1706,6 @@ struct server_context {
} }
slot->cache_tokens.resize(token_count); slot->cache_tokens.resize(token_count);
// TODO: maybe detokenize the slot->cache_tokens instead?
slot->prompt = string_format("[restored %d tokens from file]", (int) token_count);
const int64_t t_end = ggml_time_us(); const int64_t t_end = ggml_time_us();
const double t_restore_ms = (t_end - t_start) / 1000.0; const double t_restore_ms = (t_end - t_start) / 1000.0;
@ -1955,142 +1872,19 @@ struct server_context {
if (params.cont_batching || batch.n_tokens == 0) { if (params.cont_batching || batch.n_tokens == 0) {
for (auto & slot : slots) { for (auto & slot : slots) {
// this slot still has a prompt to be processed // this slot still has a prompt to be processed
if (slot.state == SLOT_STATE_PROCESSING_PROMPT) { if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) {
auto & prompt_tokens = slot.prompt_tokens; auto & prompt_tokens = slot.prompt_tokens;
// we haven't tokenized the prompt yet - do it now: // TODO: maybe move branch to outside of this loop in the future
if (prompt_tokens.empty()) { if (slot.state == SLOT_STATE_STARTED) {
SLT_INF(slot, "tokenizing prompt, len = %d\n", (int) slot.prompt.size());
slot.t_start_process_prompt = ggml_time_us(); slot.t_start_process_prompt = ggml_time_us();
slot.t_start_generation = 0; slot.t_start_generation = 0;
switch (slot.cmpl_type) {
case SERVER_TASK_CMPL_TYPE_NORMAL:
case SERVER_TASK_CMPL_TYPE_EMBEDDING:
{
prompt_tokens = tokenize(slot.prompt, llama_add_bos_token(model), true);
} break;
case SERVER_TASK_CMPL_TYPE_RERANK:
{
// require slot.prompt to be array of 2 strings
if (!slot.prompt.is_array() || slot.prompt.size() != 2) {
SLT_ERR(slot, "%s", "invalid prompt for rerank task\n");
slot.release();
send_error(slot, "invalid prompt for rerank task", ERROR_TYPE_INVALID_REQUEST);
continue;
}
// prompt: [BOS]query[EOS][SEP]doc[EOS]
prompt_tokens.clear();
prompt_tokens.push_back(llama_token_bos(model));
{
const auto part = tokenize(slot.prompt[0], false, false);
prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
}
prompt_tokens.push_back(llama_token_eos(model));
prompt_tokens.push_back(llama_token_sep(model));
{
const auto part = tokenize(slot.prompt[1], false, false);
prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
}
prompt_tokens.push_back(llama_token_eos(model));
} break;
case SERVER_TASK_CMPL_TYPE_INFILL:
{
// TODO: optimize this block by reducing memory allocations and movement
// use FIM repo-level pattern:
// ref: https://arxiv.org/pdf/2409.12186
//
// [FIM_REP]myproject
// [FIM_SEP]filename0
// extra chunk 0
// [FIM_SEP]filename1
// extra chunk 1
// ...
// [FIM_SEP]filename
// [FIM_PRE]prefix[FIM_SUF]suffix[FIM_MID]prompt
//
auto tokens_prefix = tokenize(slot.input_prefix, false, false);
auto tokens_suffix = tokenize(slot.input_suffix, false, false);
auto tokens_prompt = tokenize(slot.prompt, false, false);
slot.extra_tokens.clear();
if (llama_token_fim_rep(model) != LLAMA_TOKEN_NULL) {
static const auto k_fim_repo = tokenize("myproject\n", false, false);
slot.extra_tokens.push_back(llama_token_fim_rep(model));
slot.extra_tokens.insert(slot.extra_tokens.end(), k_fim_repo.begin(), k_fim_repo.end());
}
for (const auto & chunk : slot.input_extra) {
// { "text": string, "filename": string }
const std::string text = chunk.value("text", "");
const std::string filename = chunk.value("filename", "tmp");
if (llama_token_fim_sep(model) != LLAMA_TOKEN_NULL) {
const auto k_fim_file = tokenize(filename + "\n", false, false);
slot.extra_tokens.insert(slot.extra_tokens.end(), llama_token_fim_sep(model));
slot.extra_tokens.insert(slot.extra_tokens.end(), k_fim_file.begin(), k_fim_file.end());
} else {
// chunk separator in binary form to avoid confusing the AI
static const char k_chunk_prefix_str[] = {0x0a, 0x0a, 0x2d, 0x2d, 0x2d, 0x20, 0x73, 0x6e, 0x69, 0x70, 0x70, 0x65, 0x74, 0x20, 0x2d, 0x2d, 0x2d, 0x0a, 0x0a, 0x00};
static const auto k_chunk_prefix_tokens = tokenize(k_chunk_prefix_str, false, false);
slot.extra_tokens.insert(slot.extra_tokens.end(), k_chunk_prefix_tokens.begin(), k_chunk_prefix_tokens.end());
}
const auto chunk_tokens = tokenize(text, false, false);
slot.extra_tokens.insert(slot.extra_tokens.end(), chunk_tokens.begin(), chunk_tokens.end());
}
if (llama_token_fim_sep(model) != LLAMA_TOKEN_NULL) {
// TODO: current filename
static const auto k_fim_file = tokenize("filename\n", false, false);
slot.extra_tokens.insert(slot.extra_tokens.end(), llama_token_fim_sep(model));
slot.extra_tokens.insert(slot.extra_tokens.end(), k_fim_file.begin(), k_fim_file.end());
}
// for now pick FIM context to fit in a batch (ratio prefix:suffix = 3:1, TODO: configurable?)
const int n_suffix_take = std::min<int>(tokens_suffix.size(), (n_batch/4));
const int n_prefix_take = std::min<int>(tokens_prefix.size(), 3*(n_batch/4) - 3);
// fill the rest of the context with extra chunks
const int n_extra_take = std::min<int>(std::max<int>(0, slot.n_ctx - (n_batch) - 2*slot.n_predict), slot.extra_tokens.size());
tokens_prefix.erase(tokens_prefix.begin(), tokens_prefix.begin() + tokens_prefix.size() - n_prefix_take);
tokens_suffix.resize(n_suffix_take);
tokens_prefix.insert(tokens_prefix.begin(), llama_token_fim_pre(model));
tokens_prefix.insert(tokens_prefix.end(), tokens_prompt.begin(), tokens_prompt.end());
tokens_suffix.insert(tokens_suffix.begin(), llama_token_fim_suf(model));
auto embd_inp = params.spm_infill ? tokens_suffix : tokens_prefix;
auto embd_end = params.spm_infill ? tokens_prefix : tokens_suffix;
if (llama_add_bos_token(model)) {
embd_inp.insert(embd_inp.begin(), llama_token_bos(model));
}
SLT_DBG(slot, "extra: n_ctx = %d, n_extra_take = %d, n_extra = %d\n", slot.n_ctx, n_extra_take, (int) slot.extra_tokens.size());
// put the extra context before the FIM prefix
embd_inp.insert(embd_inp.begin(), slot.extra_tokens.end() - n_extra_take, slot.extra_tokens.end());
embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());
embd_inp.push_back(llama_token_fim_mid(model));
prompt_tokens = std::move(embd_inp);
} break;
}
slot.n_past = 0; slot.n_past = 0;
slot.n_prompt_tokens = prompt_tokens.size(); slot.n_prompt_tokens = prompt_tokens.size();
slot.state = SLOT_STATE_PROCESSING_PROMPT;
SLT_INF(slot, "prompt tokenized, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, slot.n_prompt_tokens); SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, slot.n_prompt_tokens);
// print prompt tokens (for debugging) // print prompt tokens (for debugging)
if (1) { if (1) {
@ -2115,13 +1909,18 @@ struct server_context {
continue; continue;
} }
if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) { if (slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING || slot.inf_type == SERVER_TASK_INF_TYPE_RERANK) {
// this prompt is too large to process - discard it
if (slot.n_prompt_tokens > n_ubatch) { if (slot.n_prompt_tokens > n_ubatch) {
slot.release(); slot.release();
send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER); send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER);
continue; continue;
} }
if (slot.n_prompt_tokens > slot.n_ctx) {
slot.release();
send_error(slot, "input is larger than the max context size. skipping", ERROR_TYPE_SERVER);
continue;
}
} else { } else {
if (!params.ctx_shift) { if (!params.ctx_shift) {
// if context shift is disabled, we make sure prompt size is smaller than KV size // if context shift is disabled, we make sure prompt size is smaller than KV size
@ -2145,7 +1944,7 @@ struct server_context {
const int n_block_size = n_left / 2; const int n_block_size = n_left / 2;
const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size; const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
std::vector<llama_token> new_tokens( llama_tokens new_tokens(
prompt_tokens.begin(), prompt_tokens.begin(),
prompt_tokens.begin() + slot.params.n_keep); prompt_tokens.begin() + slot.params.n_keep);
@ -2199,7 +1998,6 @@ struct server_context {
for (size_t i = 0; i < n_match; i++) { for (size_t i = 0; i < n_match; i++) {
slot.cache_tokens[head_p + i] = slot.cache_tokens[head_c + i]; slot.cache_tokens[head_p + i] = slot.cache_tokens[head_c + i];
slot.n_past++; slot.n_past++;
} }
@ -2226,7 +2024,7 @@ struct server_context {
} }
// non-causal tasks require to fit the entire prompt in the physical batch // non-causal tasks require to fit the entire prompt in the physical batch
if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) { if (slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING || slot.inf_type == SERVER_TASK_INF_TYPE_RERANK) {
// cannot fit the prompt in the current batch - will try next iter // cannot fit the prompt in the current batch - will try next iter
if (batch.n_tokens + slot.n_prompt_tokens > n_batch) { if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
continue; continue;
@ -2235,8 +2033,8 @@ struct server_context {
// check that we are in the right batch_type, if not defer the slot // check that we are in the right batch_type, if not defer the slot
const bool slot_type = const bool slot_type =
slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING ||
slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK ? 1 : 0; slot.inf_type == SERVER_TASK_INF_TYPE_RERANK ? 1 : 0;
if (batch_type == -1) { if (batch_type == -1) {
batch_type = slot_type; batch_type = slot_type;
@ -2354,7 +2152,7 @@ struct server_context {
} }
if (slot.state == SLOT_STATE_DONE_PROMPT) { if (slot.state == SLOT_STATE_DONE_PROMPT) {
if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING) { if (slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING) {
// prompt evaluated for embedding // prompt evaluated for embedding
send_embedding(slot, batch_view); send_embedding(slot, batch_view);
slot.release(); slot.release();
@ -2362,7 +2160,7 @@ struct server_context {
continue; // continue loop of slots continue; // continue loop of slots
} }
if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) { if (slot.inf_type == SERVER_TASK_INF_TYPE_RERANK) {
send_rerank(slot, batch_view); send_rerank(slot, batch_view);
slot.release(); slot.release();
slot.i_batch = -1; slot.i_batch = -1;
@ -2609,7 +2407,7 @@ int main(int argc, char ** argv) {
auto middleware_server_state = [&res_error, &state](const httplib::Request & req, httplib::Response & res) { auto middleware_server_state = [&res_error, &state](const httplib::Request & req, httplib::Response & res) {
server_state current_state = state.load(); server_state current_state = state.load();
if (current_state == SERVER_STATE_LOADING_MODEL) { if (current_state == SERVER_STATE_LOADING_MODEL) {
auto tmp = string_split(req.path, '.'); auto tmp = string_split<std::string>(req.path, '.');
if (req.path == "/" || tmp.back() == "html") { if (req.path == "/" || tmp.back() == "html") {
res.set_content(reinterpret_cast<const char*>(loading_html), loading_html_len, "text/html; charset=utf-8"); res.set_content(reinterpret_cast<const char*>(loading_html), loading_html_len, "text/html; charset=utf-8");
res.status = 503; res.status = 503;
@ -2916,13 +2714,13 @@ int main(int argc, char ** argv) {
res_ok(res, {{ "success", true }}); res_ok(res, {{ "success", true }});
}; };
const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](server_task_cmpl_type cmpl_type, json & data, httplib::Response & res) { const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](server_task_inf_type inf_type, json & data, httplib::Response & res) {
if (ctx_server.params.embedding || ctx_server.params.reranking) { if (ctx_server.params.embedding || ctx_server.params.reranking) {
res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings` or `--reranking`", ERROR_TYPE_NOT_SUPPORTED)); res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings` or `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
return; return;
} }
std::vector<server_task> tasks = ctx_server.create_tasks_cmpl(data, cmpl_type); std::vector<server_task> tasks = ctx_server.create_tasks_inference(data, inf_type);
ctx_server.queue_results.add_waiting_tasks(tasks); ctx_server.queue_results.add_waiting_tasks(tasks);
ctx_server.queue_tasks.post(tasks); ctx_server.queue_tasks.post(tasks);
@ -2968,10 +2766,11 @@ int main(int argc, char ** argv) {
const auto handle_completions = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) { const auto handle_completions = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
json data = json::parse(req.body); json data = json::parse(req.body);
return handle_completions_generic(SERVER_TASK_CMPL_TYPE_NORMAL, data, res); return handle_completions_generic(SERVER_TASK_INF_TYPE_COMPLETION, data, res);
}; };
const auto handle_infill = [&ctx_server, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) { const auto handle_infill = [&ctx_server, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
// check model compatibility
std::string err; std::string err;
if (llama_token_fim_pre(ctx_server.model) == LLAMA_TOKEN_NULL) { if (llama_token_fim_pre(ctx_server.model) == LLAMA_TOKEN_NULL) {
err += "prefix token is missing. "; err += "prefix token is missing. ";
@ -2982,14 +2781,42 @@ int main(int argc, char ** argv) {
if (llama_token_fim_mid(ctx_server.model) == LLAMA_TOKEN_NULL) { if (llama_token_fim_mid(ctx_server.model) == LLAMA_TOKEN_NULL) {
err += "middle token is missing. "; err += "middle token is missing. ";
} }
if (!err.empty()) { if (!err.empty()) {
res_error(res, format_error_response(string_format("Infill is not supported by this model: %s", err.c_str()), ERROR_TYPE_NOT_SUPPORTED)); res_error(res, format_error_response(string_format("Infill is not supported by this model: %s", err.c_str()), ERROR_TYPE_NOT_SUPPORTED));
return; return;
} }
json data = json::parse(req.body); json data = json::parse(req.body);
return handle_completions_generic(SERVER_TASK_CMPL_TYPE_INFILL, data, res);
// validate input
// "input_prefix" and "input_suffix" are mandatory for infill requests.
// Return right after reporting the error: without the early return the
// request would fall through and still be dispatched as an infill task
// even though an error response has already been written.
if (!data.contains("input_prefix")) {
    res_error(res, format_error_response("\"input_prefix\" is required", ERROR_TYPE_INVALID_REQUEST));
    return;
}
if (!data.contains("input_suffix")) {
    res_error(res, format_error_response("\"input_suffix\" is required", ERROR_TYPE_INVALID_REQUEST));
    return;
}
if (data.contains("input_extra") && !data.at("input_extra").is_array()) {
res_error(res, format_error_response("\"input_extra\" must be an array of {\"filename\": string, \"text\": string}", ERROR_TYPE_INVALID_REQUEST));
return;
}
json input_extra = json_value(data, "input_extra", json::array());
for (const auto & chunk : input_extra) {
// { "text": string, "filename": string }
if (!chunk.contains("text") || !chunk.at("text").is_string()) {
res_error(res, format_error_response("extra_context chunk must contain a \"text\" field with a string value", ERROR_TYPE_INVALID_REQUEST));
return;
}
// filename is optional
if (chunk.contains("filename") && !chunk.at("filename").is_string()) {
res_error(res, format_error_response("extra_context chunk's \"filename\" field must be a string", ERROR_TYPE_INVALID_REQUEST));
return;
}
}
data["input_extra"] = input_extra; // default to empty array if it's not exist
return handle_completions_generic(SERVER_TASK_INF_TYPE_INFILL, data, res);
}; };
// TODO: maybe merge this function with "handle_completions_generic" // TODO: maybe merge this function with "handle_completions_generic"
@ -3001,7 +2828,7 @@ int main(int argc, char ** argv) {
json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template); json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
std::vector<server_task> tasks = ctx_server.create_tasks_cmpl(data, SERVER_TASK_CMPL_TYPE_NORMAL); std::vector<server_task> tasks = ctx_server.create_tasks_inference(data, SERVER_TASK_INF_TYPE_COMPLETION);
ctx_server.queue_results.add_waiting_tasks(tasks); ctx_server.queue_results.add_waiting_tasks(tasks);
ctx_server.queue_tasks.post(tasks); ctx_server.queue_tasks.post(tasks);
@ -3074,7 +2901,7 @@ int main(int argc, char ** argv) {
const bool add_special = json_value(body, "add_special", false); const bool add_special = json_value(body, "add_special", false);
const bool with_pieces = json_value(body, "with_pieces", false); const bool with_pieces = json_value(body, "with_pieces", false);
std::vector<llama_token> tokens = ctx_server.tokenize(body.at("content"), add_special, true); llama_tokens tokens = tokenize_mixed(ctx_server.ctx, body.at("content"), add_special, true);
if (with_pieces) { if (with_pieces) {
for (const auto& token : tokens) { for (const auto& token : tokens) {
@ -3111,7 +2938,7 @@ int main(int argc, char ** argv) {
std::string content; std::string content;
if (body.count("tokens") != 0) { if (body.count("tokens") != 0) {
const std::vector<llama_token> tokens = body.at("tokens"); const llama_tokens tokens = body.at("tokens");
content = tokens_to_str(ctx_server.ctx, tokens.cbegin(), tokens.cend()); content = tokens_to_str(ctx_server.ctx, tokens.cbegin(), tokens.cend());
} }
@ -3145,7 +2972,7 @@ int main(int argc, char ** argv) {
json responses = json::array(); json responses = json::array();
bool error = false; bool error = false;
{ {
std::vector<server_task> tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_EMBEDDING); std::vector<server_task> tasks = ctx_server.create_tasks_inference({{"prompt", prompt}}, SERVER_TASK_INF_TYPE_EMBEDDING);
ctx_server.queue_results.add_waiting_tasks(tasks); ctx_server.queue_results.add_waiting_tasks(tasks);
ctx_server.queue_tasks.post(tasks); ctx_server.queue_tasks.post(tasks);
@ -3222,7 +3049,7 @@ int main(int argc, char ** argv) {
json responses = json::array(); json responses = json::array();
bool error = false; bool error = false;
{ {
std::vector<server_task> tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_RERANK); std::vector<server_task> tasks = ctx_server.create_tasks_inference({{"prompt", prompt}}, SERVER_TASK_INF_TYPE_RERANK);
ctx_server.queue_results.add_waiting_tasks(tasks); ctx_server.queue_results.add_waiting_tasks(tasks);
ctx_server.queue_tasks.post(tasks); ctx_server.queue_tasks.post(tasks);

View file

@ -0,0 +1,36 @@
@llama.cpp
@infill
Feature: llama.cpp server
# The current model is made by adding FIM tokens to the existing stories260K
# We may want to use a better model in the future, maybe something like SmolLM 360M
Background: Server startup
Given a server listening on localhost:8080
And a model file tinyllamas/stories260K-infill.gguf from HF repo ggml-org/models
And a model file test-model-infill.gguf
And a model alias tinyllama-infill
And 42 as server seed
And 1024 as batch size
And 1024 as ubatch size
And 2048 KV cache size
And 64 max tokens to predict
And 0.0 temperature
Then the server is starting
Then the server is healthy
Scenario: Infill without input_extra
Given a prompt "Complete this"
And an infill input extra none none
And an infill input prefix "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n int n_threads = llama_"
And an infill input suffix "}\n"
And an infill request with no api error
Then 64 tokens are predicted matching One|day|she|saw|big|scary|bird
Scenario: Infill with input_extra
Given a prompt "Complete this"
And an infill input extra "llama.h" "LLAMA_API int32_t llama_n_threads();\n"
And an infill input prefix "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n int n_threads = llama_"
And an infill input suffix "}\n"
And an infill request with no api error
Then 64 tokens are predicted matching cuts|Jimmy|mom|came|into|the|room

View file

@ -80,6 +80,11 @@ def step_server_config(context, server_fqdn: str, server_port: str):
context.lora_file = None context.lora_file = None
context.disable_ctx_shift = False context.disable_ctx_shift = False
# infill
context.infill_input_extra = None
context.infill_input_suffix = ''
context.infill_input_prefix = ''
context.tasks_result = [] context.tasks_result = []
context.concurrent_tasks = [] context.concurrent_tasks = []
context.prompts = [] context.prompts = []
@ -291,6 +296,28 @@ async def step_request_completion(context, api_error: Literal['raised'] | str):
assert completion == api_error_code, f"completion must be an {api_error_code} status code: {completion}" assert completion == api_error_code, f"completion must be an {api_error_code} status code: {completion}"
@step('an infill request with {api_error} api error')
@async_run_until_complete
async def step_request_infill(context, api_error: Literal['raised'] | str):
    """POST an infill request built from the scenario context and store the JSON reply.

    The payload carries the first prompt, the infill prefix/suffix and the
    sampling settings found on ``context``; ``input_extra`` is only sent when
    the scenario provided one. The decoded response replaces
    ``context.tasks_result``.

    Raises:
        ValueError: if ``api_error`` is anything other than ``'no'`` (error
            scenarios are not implemented for infill yet).
    """
    # NOTE: named distinctly from the completion step so that defining this
    # function does not rebind the module-level `step_request_completion`.
    if api_error != 'no':
        raise ValueError(f'api_error={api_error} is not yet implemented')

    payload = {
        "prompt": context.prompts[0],
        "input_suffix": context.infill_input_suffix,
        "input_prefix": context.infill_input_prefix,
        "n_predict": context.n_predict,
        "seed": context.seed,
        "temperature": context.temperature,
    }
    # input_extra is optional: omit the key entirely when no extra context was given
    if context.infill_input_extra is not None:
        payload['input_extra'] = context.infill_input_extra

    async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session:
        async with session.post(f'{context.base_url}/infill',
                                json=payload) as response:
            assert response.status == 200, f"infill request failed with status {response.status}"
            context.tasks_result = [await response.json()]
@step('{predicted_n:d} tokens are predicted matching {re_content}') @step('{predicted_n:d} tokens are predicted matching {re_content}')
def step_n_tokens_predicted_with_content(context, predicted_n, re_content): def step_n_tokens_predicted_with_content(context, predicted_n, re_content):
context.completion = context.tasks_result.pop() context.completion = context.tasks_result.pop()
@ -539,6 +566,25 @@ def step_a_prompt_prompt(context, prompt):
context.n_prompts = len(context.prompts) context.n_prompts = len(context.prompts)
# TODO: allow this to be repeated
@step('an infill input extra {filename} {text}')
def step_infill_input_extra(context, filename, text):
    """Record the optional extra-context chunk for the next infill request.

    A filename of ``'none'`` clears the extra context; any other value stores
    a single-element list holding one ``{'filename', 'text'}`` chunk.
    """
    no_extra = filename == 'none'
    context.infill_input_extra = None if no_extra else [{'filename': filename, 'text': text}]
@step('an infill input suffix {text}')
def step_infill_input_suffix(context, text):
    """Remember ``text`` on the context as the suffix for the next infill request."""
    suffix = text
    context.infill_input_suffix = suffix
@step('an infill input prefix {text}')
def step_infill_input_prefix(context, text):
    """Remember ``text`` on the context as the prefix for the next infill request."""
    prefix = text
    context.infill_input_prefix = prefix
@step('{num_prompts:d} prompts {prompt} with seed {seed:d}') @step('{num_prompts:d} prompts {prompt} with seed {seed:d}')
def step_many_prompts(context, num_prompts, prompt, seed): def step_many_prompts(context, num_prompts, prompt, seed):
if context.seed is None: if context.seed is None:

View file

@ -226,7 +226,6 @@
top_k: 40, // <= 0 to use vocab size top_k: 40, // <= 0 to use vocab size
top_p: 0.95, // 1.0 = disabled top_p: 0.95, // 1.0 = disabled
min_p: 0.05, // 0 = disabled min_p: 0.05, // 0 = disabled
tfs_z: 1.0, // 1.0 = disabled
typical_p: 1.0, // 1.0 = disabled typical_p: 1.0, // 1.0 = disabled
presence_penalty: 0.0, // 0.0 = disabled presence_penalty: 0.0, // 0.0 = disabled
frequency_penalty: 0.0, // 0.0 = disabled frequency_penalty: 0.0, // 0.0 = disabled
@ -788,7 +787,6 @@
<details> <details>
<summary>More options</summary> <summary>More options</summary>
<fieldset class="two"> <fieldset class="two">
${FloatField({ label: "TFS-Z", max: 1.0, min: 0.0, name: "tfs_z", step: 0.01, value: params.value.tfs_z })}
${FloatField({ label: "Typical P", max: 1.0, min: 0.0, name: "typical_p", step: 0.01, value: params.value.typical_p })} ${FloatField({ label: "Typical P", max: 1.0, min: 0.0, name: "typical_p", step: 0.01, value: params.value.typical_p })}
${FloatField({ label: "Presence penalty", max: 1.0, min: 0.0, name: "presence_penalty", step: 0.01, value: params.value.presence_penalty })} ${FloatField({ label: "Presence penalty", max: 1.0, min: 0.0, name: "presence_penalty", step: 0.01, value: params.value.presence_penalty })}
${FloatField({ label: "Frequency penalty", max: 1.0, min: 0.0, name: "frequency_penalty", step: 0.01, value: params.value.frequency_penalty })} ${FloatField({ label: "Frequency penalty", max: 1.0, min: 0.0, name: "frequency_penalty", step: 0.01, value: params.value.frequency_penalty })}

View file

@ -229,7 +229,6 @@
top_k: 40, // <= 0 to use vocab size top_k: 40, // <= 0 to use vocab size
top_p: 0.95, // 1.0 = disabled top_p: 0.95, // 1.0 = disabled
min_p: 0.05, // 0 = disabled min_p: 0.05, // 0 = disabled
tfs_z: 1.0, // 1.0 = disabled
typical_p: 1.0, // 1.0 = disabled typical_p: 1.0, // 1.0 = disabled
presence_penalty: 0.0, // 0.0 = disabled presence_penalty: 0.0, // 0.0 = disabled
frequency_penalty: 0.0, // 0.0 = disabled frequency_penalty: 0.0, // 0.0 = disabled
@ -791,7 +790,6 @@
<details> <details>
<summary>More options</summary> <summary>More options</summary>
<fieldset class="two"> <fieldset class="two">
${FloatField({ label: "TFS-Z", max: 1.0, min: 0.0, name: "tfs_z", step: 0.01, value: params.value.tfs_z })}
${FloatField({ label: "Typical P", max: 1.0, min: 0.0, name: "typical_p", step: 0.01, value: params.value.typical_p })} ${FloatField({ label: "Typical P", max: 1.0, min: 0.0, name: "typical_p", step: 0.01, value: params.value.typical_p })}
${FloatField({ label: "Presence penalty", max: 1.0, min: 0.0, name: "presence_penalty", step: 0.01, value: params.value.presence_penalty })} ${FloatField({ label: "Presence penalty", max: 1.0, min: 0.0, name: "presence_penalty", step: 0.01, value: params.value.presence_penalty })}
${FloatField({ label: "Frequency penalty", max: 1.0, min: 0.0, name: "frequency_penalty", step: 0.01, value: params.value.frequency_penalty })} ${FloatField({ label: "Frequency penalty", max: 1.0, min: 0.0, name: "frequency_penalty", step: 0.01, value: params.value.frequency_penalty })}

View file

@ -24,6 +24,22 @@
#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613" #define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613"
using json = nlohmann::ordered_json; using json = nlohmann::ordered_json;
// convenience alias: a sequence of token ids
using llama_tokens = std::vector<llama_token>;
// Logging helpers. Each macro forwards to the matching LOG_* macro (defined
// elsewhere) and prepends a fixed prefix containing the calling function's
// name (__func__). "%12.*s" with the argument 12 requests a field width of 12
// and a precision of 12 for that name.
// NOTE(review): __VA_ARGS__ follows a comma without "##", so callers are
// presumably expected to pass at least one argument after fmt — confirm.
//
// SLT_*: slot-scoped messages; the prefix also carries (slot).id and (slot).id_task
#define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
#define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
#define SLT_ERR(slot, fmt, ...) LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
#define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
// SRV_*: messages tagged with the "srv" prefix
#define SRV_INF(fmt, ...) LOG_INF("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define SRV_WRN(fmt, ...) LOG_WRN("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define SRV_ERR(fmt, ...) LOG_ERR("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define SRV_DBG(fmt, ...) LOG_DBG("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
// QUE_*: messages tagged with the "que" prefix
#define QUE_INF(fmt, ...) LOG_INF("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define QUE_WRN(fmt, ...) LOG_WRN("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define QUE_ERR(fmt, ...) LOG_ERR("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define QUE_DBG(fmt, ...) LOG_DBG("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
// https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11 // https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11
enum error_type { enum error_type {
@ -52,9 +68,237 @@ static T json_value(const json & body, const std::string & key, const T & defaul
} }
// //
// chat template utils // tokenizer and input processing utils
// //
static bool json_is_array_of_numbers(const json & data) {
if (data.is_array()) {
for (const auto & e : data) {
if (!e.is_number_integer()) {
return false;
}
}
return true;
}
return false;
}
// Returns true only when `data` is a JSON array that contains at least one
// string AND at least one integer number; false otherwise (including non-arrays).
static bool json_is_array_of_mixed_numbers_strings(const json & data) {
    if (!data.is_array()) {
        return false;
    }
    bool has_string  = false;
    bool has_integer = false;
    for (const auto & item : data) {
        if (item.is_string()) {
            has_string = true;
        }
        if (item.is_number_integer()) {
            has_integer = true;
        }
        // both kinds seen - no need to look at the remaining elements
        if (has_string && has_integer) {
            return true;
        }
    }
    return false;
}
/**
 * Tokenize a single prompt given as JSON. This handles 2 cases:
 * - only string, example: "string"
 * - mixed string and tokens, example: [12, 34, "string", 56, 78]
 *
 * String pieces are tokenized with common_tokenize(); numeric elements are
 * appended verbatim as token ids.
 *
 * `add_special` is honoured only when json_prompt is a string, or when the
 * FIRST element of the json_prompt array is a string; every later string
 * piece is tokenized with add_special = false.
 *
 * A non-array json_prompt is read with get<std::string>(), so passing any
 * other JSON type results in whatever error that conversion raises.
 */
static llama_tokens tokenize_mixed(const llama_context * ctx, const json & json_prompt, bool add_special, bool parse_special) {
    if (!json_prompt.is_array()) {
        const auto text = json_prompt.get<std::string>();
        return common_tokenize(ctx, text, add_special, parse_special);
    }

    llama_tokens prompt_tokens;
    bool first = true;
    for (const auto & piece : json_prompt) {
        if (piece.is_string()) {
            const auto text = piece.get<std::string>();
            // special tokens may only be added in front of the very first element
            const llama_tokens piece_tokens = common_tokenize(ctx, text, add_special && first, parse_special);
            prompt_tokens.insert(prompt_tokens.end(), piece_tokens.begin(), piece_tokens.end());
        } else {
            prompt_tokens.push_back(piece.get<llama_token>());
        }
        // any element, string or token id, consumes the "first" position
        first = false;
    }
    return prompt_tokens;
}
/**
 * Split the request's "prompt" value into one or more prompts and tokenize each.
 *
 * A single prompt may be written as:
 * - "prompt": "string"
 * - "prompt": [12, 34, 56]
 * - "prompt": [12, 34, "string", 56, 78]
 * Several prompts (multi-tasks) are written as an array of the forms above:
 * - "prompt": ["string1", "string2"]
 * - "prompt": ["string1", [12, 34, 56]]
 * - "prompt": [[12, 34, "string", 56, 78], [12, 34, 56]]
 *
 * Returns one token list per prompt. Throws std::runtime_error when the value
 * (or one of its elements) matches none of the accepted forms.
 */
static std::vector<llama_tokens> tokenize_input_prompts(llama_context * ctx, const json & json_prompt, bool add_special, bool parse_special) {
    // tries to interpret `item` as ONE prompt; on success stores its tokens in `out`
    const auto tokenize_single = [&](const json & item, llama_tokens & out) -> bool {
        if (item.is_string() || json_is_array_of_mixed_numbers_strings(item)) {
            // string or mixed
            out = tokenize_mixed(ctx, item, add_special, parse_special);
            return true;
        }
        if (json_is_array_of_numbers(item)) {
            // array of tokens
            out = item.get<llama_tokens>();
            return true;
        }
        return false;
    };

    std::vector<llama_tokens> prompts;

    llama_tokens whole;
    if (tokenize_single(json_prompt, whole)) {
        prompts.push_back(std::move(whole));
        return prompts;
    }

    if (!json_prompt.is_array()) {
        throw std::runtime_error("\"prompt\" must be a string, an list of tokens, a list of mixed strings & tokens, or a list of prompts");
    }

    // array of prompts
    prompts.reserve(json_prompt.size());
    for (const auto & item : json_prompt) {
        llama_tokens tokens;
        if (!tokenize_single(item, tokens)) {
            throw std::runtime_error("element of \"prompt\" must be a string, an list of tokens, or a list of mixed strings & tokens");
        }
        prompts.push_back(std::move(tokens));
    }
    return prompts;
}
//
// template utils
//
// Build the token sequence for a rerank task from an already tokenized
// query and document, laid out as: [BOS]query[EOS][SEP]doc[EOS]
static llama_tokens format_rerank(const struct llama_model * model, const llama_tokens & query, const llama_tokens & doc) {
    llama_tokens tokens;
    // 4 extra slots: BOS, EOS, SEP, EOS
    tokens.reserve(doc.size() + query.size() + 4);

    // [BOS]query[EOS]
    tokens.push_back(llama_token_bos(model));
    tokens.insert(tokens.end(), query.cbegin(), query.cend());
    tokens.push_back(llama_token_eos(model));

    // [SEP]doc[EOS]
    tokens.push_back(llama_token_sep(model));
    tokens.insert(tokens.end(), doc.cbegin(), doc.cend());
    tokens.push_back(llama_token_eos(model));

    return tokens;
}
// format infill task
//
// Builds the fill-in-the-middle (FIM) token sequence for an infill request.
//
//   ctx            - context used for tokenization and to obtain the model
//   input_prefix   - text/tokens before the hole (tokenized with tokenize_mixed)
//   input_suffix   - text/tokens after the hole  (tokenized with tokenize_mixed)
//   input_extra    - iterable of { "text": string, "filename": string } chunks used as extra context
//   n_batch        - batch size; bounds how many prefix/suffix tokens are kept
//   n_predict      - used only to size the budget for the extra context
//   n_ctx          - context size; used for the extra-context budget and the initial reserve
//   spm_infill     - when true the suffix part is emitted before the prefix part
//   tokens_prompt  - already-tokenized prompt, appended to the end of the prefix part
//
// Returns the assembled token list, terminated by the FIM_MID token.
static llama_tokens format_infill(
const llama_context * ctx,
const json & input_prefix,
const json & input_suffix,
const json & input_extra,
const int n_batch,
const int n_predict,
const int n_ctx,
const bool spm_infill,
const llama_tokens & tokens_prompt
) {
// TODO: optimize this block by reducing memory allocations and movement
// use FIM repo-level pattern:
// ref: https://arxiv.org/pdf/2409.12186
//
// [FIM_REP]myproject
// [FIM_SEP]filename0
// extra chunk 0
// [FIM_SEP]filename1
// extra chunk 1
// ...
// [FIM_SEP]filename
// [FIM_PRE]prefix[FIM_SUF]suffix[FIM_MID]prompt
//
// NOTE(review): the code below appends tokens_prompt to the end of the prefix part (i.e. before
// FIM_SUF when spm_infill is false) and emits FIM_MID last, which differs from the last line of
// the diagram above - confirm which one is intended.
llama_tokens extra_tokens;
extra_tokens.reserve(n_ctx);
auto model = llama_get_model(ctx);
auto tokens_prefix = tokenize_mixed(ctx, input_prefix, false, false);
auto tokens_suffix = tokenize_mixed(ctx, input_suffix, false, false);
// optional repo header: emitted only when the model defines a FIM_REP token
if (llama_token_fim_rep(model) != LLAMA_TOKEN_NULL) {
// TODO: make project name an input
// function-local static: tokenized once, on the first call that reaches this line, with that call's ctx
static const auto k_fim_repo = common_tokenize(ctx, "myproject\n", false, false);
extra_tokens.push_back(llama_token_fim_rep(model));
extra_tokens.insert(extra_tokens.end(), k_fim_repo.begin(), k_fim_repo.end());
}
// one section per extra chunk: a header (FIM_SEP + filename, or a textual separator) followed by the chunk text
for (const auto & chunk : input_extra) {
// { "text": string, "filename": string }
// missing keys fall back to an empty text and the filename "tmp"
const std::string text = json_value(chunk, "text", std::string());
const std::string filename = json_value(chunk, "filename", std::string("tmp"));
if (llama_token_fim_sep(model) != LLAMA_TOKEN_NULL) {
const auto k_fim_file = common_tokenize(ctx, filename + "\n", false, false);
extra_tokens.insert(extra_tokens.end(), llama_token_fim_sep(model));
extra_tokens.insert(extra_tokens.end(), k_fim_file.begin(), k_fim_file.end());
} else {
// chunk separator in binary form to avoid confusing the AI
// the bytes spell "\n\n--- snippet ---\n\n" followed by the terminating NUL
static const char k_chunk_prefix_str[] = {0x0a, 0x0a, 0x2d, 0x2d, 0x2d, 0x20, 0x73, 0x6e, 0x69, 0x70, 0x70, 0x65, 0x74, 0x20, 0x2d, 0x2d, 0x2d, 0x0a, 0x0a, 0x00};
// function-local static: tokenized once, on first use, with that call's ctx
static const auto k_chunk_prefix_tokens = common_tokenize(ctx, k_chunk_prefix_str, false, false);
extra_tokens.insert(extra_tokens.end(), k_chunk_prefix_tokens.begin(), k_chunk_prefix_tokens.end());
}
const auto chunk_tokens = common_tokenize(ctx, text, false, false);
extra_tokens.insert(extra_tokens.end(), chunk_tokens.begin(), chunk_tokens.end());
}
// header for the file being completed (placeholder name), only when FIM_SEP exists
if (llama_token_fim_sep(model) != LLAMA_TOKEN_NULL) {
// TODO: current filename
static const auto k_fim_file = common_tokenize(ctx, "filename\n", false, false);
extra_tokens.insert(extra_tokens.end(), llama_token_fim_sep(model));
extra_tokens.insert(extra_tokens.end(), k_fim_file.begin(), k_fim_file.end());
}
// for now pick FIM context to fit in a batch (ratio prefix:suffix = 3:1, TODO: configurable?)
const int n_prefix_take = std::min<int>(tokens_prefix.size(), 3*(n_batch/4));
// NOTE(review): (2 + tokens_prompt.size()) is unsigned, so the subtraction is done in unsigned
// arithmetic and only becomes negative again through the conversion to int inside std::max<int>;
// the "2" presumably reserves room for special tokens - confirm.
const int n_suffix_take = std::min<int>(tokens_suffix.size(), std::max<int>(0, (n_batch/4) - (2 + tokens_prompt.size())));
SRV_DBG("n_prefix_take = %d, n_suffix_take = %d, total = %d\n", n_prefix_take, n_suffix_take, (n_prefix_take + n_suffix_take));
// fill the rest of the context with extra chunks
const int n_extra_take = std::min<int>(std::max<int>(0, n_ctx - (n_batch) - 2*n_predict), extra_tokens.size());
// keep the LAST n_prefix_take prefix tokens (drop from the front) ...
tokens_prefix.erase(tokens_prefix.begin(), tokens_prefix.begin() + tokens_prefix.size() - n_prefix_take);
// ... and the FIRST n_suffix_take suffix tokens (drop from the back)
tokens_suffix.resize(n_suffix_take);
// prefix part becomes: FIM_PRE + prefix + prompt; suffix part becomes: FIM_SUF + suffix
tokens_prefix.insert(tokens_prefix.begin(), llama_token_fim_pre(model));
tokens_prefix.insert(tokens_prefix.end(), tokens_prompt.begin(), tokens_prompt.end());
tokens_suffix.insert(tokens_suffix.begin(), llama_token_fim_suf(model));
// spm_infill selects the order: suffix part first when true, prefix part first otherwise
auto embd_inp = spm_infill ? tokens_suffix : tokens_prefix;
auto embd_end = spm_infill ? tokens_prefix : tokens_suffix;
if (llama_add_bos_token(model)) {
embd_inp.insert(embd_inp.begin(), llama_token_bos(model));
}
SRV_DBG("extra: n_ctx = %d, n_extra_take = %d, n_extra = %d\n", n_ctx, n_extra_take, (int) extra_tokens.size());
// put the extra context before the FIM prefix
// only the last n_extra_take extra tokens are used; they also end up in front of the BOS token inserted above
embd_inp.insert(embd_inp.begin(), extra_tokens.end() - n_extra_take, extra_tokens.end());
embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());
embd_inp.push_back(llama_token_fim_mid(model));
return embd_inp;
}
// Format given chat. If tmpl is empty, we take the template from model metadata // Format given chat. If tmpl is empty, we take the template from model metadata
inline std::string format_chat(const struct llama_model * model, const std::string & tmpl, const std::vector<json> & messages) { inline std::string format_chat(const struct llama_model * model, const std::string & tmpl, const std::vector<json> & messages) {
std::vector<common_chat_msg> chat; std::vector<common_chat_msg> chat;
@ -229,18 +473,6 @@ static size_t find_partial_stop_string(const std::string &stop, const std::strin
return std::string::npos; return std::string::npos;
} }
static bool json_is_array_of_numbers(const json & data) {
if (data.is_array()) {
for (const auto & e : data) {
if (!e.is_number()) {
return false;
}
}
return true;
}
return false;
}
// TODO: reuse llama_detokenize // TODO: reuse llama_detokenize
template <class Iter> template <class Iter>
static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) { static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) {
@ -375,7 +607,7 @@ static json oaicompat_completion_params_parse(
} }
// Copy remaining properties to llama_params // Copy remaining properties to llama_params
// This allows user to use llama.cpp-specific params like "mirostat", "tfs_z",... via OAI endpoint. // This allows user to use llama.cpp-specific params like "mirostat", ... via OAI endpoint.
// See "launch_slot_with_task()" for a complete list of params supported by llama.cpp // See "launch_slot_with_task()" for a complete list of params supported by llama.cpp
for (const auto & item : body.items()) { for (const auto & item : body.items()) {
// Exception: if "n_predict" is present, we overwrite the value specified earlier by "max_tokens" // Exception: if "n_predict" is present, we overwrite the value specified earlier by "max_tokens"

View file

@ -1484,14 +1484,19 @@ static void ggml_cuda_op_mul_mat(
const size_t nbytes_data = ggml_nbytes(src0); const size_t nbytes_data = ggml_nbytes(src0);
const size_t nbytes_padding = ggml_row_size(src0->type, MATRIX_ROW_PADDING - ne00 % MATRIX_ROW_PADDING); const size_t nbytes_padding = ggml_row_size(src0->type, MATRIX_ROW_PADDING - ne00 % MATRIX_ROW_PADDING);
dev[id].src0_dd = dev[id].src0_dd_alloc.alloc(ctx.pool(id), nbytes_data + nbytes_padding); dev[id].src0_dd = dev[id].src0_dd_alloc.alloc(ctx.pool(id), nbytes_data + nbytes_padding);
// TODO: remove this for MUSA once the Guilty Lockup issue is resolved
#ifndef GGML_USE_MUSA
CUDA_CHECK(cudaMemsetAsync(dev[id].src0_dd, 0, nbytes_data + nbytes_padding, stream)); CUDA_CHECK(cudaMemsetAsync(dev[id].src0_dd, 0, nbytes_data + nbytes_padding, stream));
#else // GGML_USE_MUSA
CUDA_CHECK(cudaMemsetAsync(dev[id].src0_dd + nbytes_data, 0, nbytes_padding, stream));
#endif // !GGML_USE_MUSA
} }
// If src0 is on a temporary compute buffer (partial offloading) there may be some padding that needs to be cleared: // If src0 is on a temporary compute buffer (partial offloading) there may be some padding that needs to be cleared:
if (ne00 % MATRIX_ROW_PADDING != 0 && ggml_is_quantized(src0->type) && ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE && src0->view_src == nullptr) { if (ne00 % MATRIX_ROW_PADDING != 0 && ggml_is_quantized(src0->type) && ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE && src0->view_src == nullptr) {
const size_t nbytes_data = ggml_row_size(src0->type, (dev[id].row_high - dev[id].row_low)*ne00); const size_t nbytes_data = ggml_row_size(src0->type, (dev[id].row_high - dev[id].row_low)*ne00);
const size_t nbytes_padding = ggml_row_size(src0->type, MATRIX_ROW_PADDING - ne00 % MATRIX_ROW_PADDING); const size_t nbytes_padding = ggml_row_size(src0->type, MATRIX_ROW_PADDING - ne00 % MATRIX_ROW_PADDING);
CUDA_CHECK(cudaMemsetAsync(dev[id].src0_dd + nbytes_data , 0, nbytes_padding, stream)); CUDA_CHECK(cudaMemsetAsync(dev[id].src0_dd + nbytes_data, 0, nbytes_padding, stream));
} }
if (src1_on_device && src1_is_contiguous) { if (src1_on_device && src1_is_contiguous) {

View file

@ -1,6 +1,6 @@
#include "common.cuh" #include "common.cuh"
#define CUDA_CPY_BLOCK_SIZE 32 #define CUDA_CPY_BLOCK_SIZE 64
void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1); void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1);

View file

@ -1015,19 +1015,21 @@ static void ggml_metal_encode_node(
id<MTLBuffer> id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil; id<MTLBuffer> id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil;
id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(dst, &offs_dst) : nil; id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(dst, &offs_dst) : nil;
//GGML_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op)); #if 0
//if (src0) { GGML_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
// GGML_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02, if (src0) {
// ggml_is_contiguous(src0), src0->name); GGML_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03,
//} ggml_is_contiguous(src0), src0->name);
//if (src1) { }
// GGML_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12, if (src1) {
// ggml_is_contiguous(src1), src1->name); GGML_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13,
//} ggml_is_contiguous(src1), src1->name);
//if (dst) { }
// GGML_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2, if (dst) {
// dst->name); GGML_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3,
//} dst->name);
}
#endif
id<MTLDevice> device = ctx_dev->mtl_device; id<MTLDevice> device = ctx_dev->mtl_device;
@ -1810,14 +1812,16 @@ static void ggml_metal_encode_node(
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5]; [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6]; [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7]; [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:7];
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8]; [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:8];
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9]; [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:9];
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10]; [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:10];
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11]; [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:11];
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12]; [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:12];
[encoder setBytes:&r2 length:sizeof(r2) atIndex:13]; [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
[encoder setBytes:&r3 length:sizeof(r3) atIndex:14]; [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
[encoder setBytes:&r2 length:sizeof(r2) atIndex:15];
[encoder setBytes:&r3 length:sizeof(r3) atIndex:16];
[encoder setThreadgroupMemoryLength:8192 atIndex:0]; [encoder setThreadgroupMemoryLength:8192 atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
} else { } else {
@ -1986,20 +1990,22 @@ static void ggml_metal_encode_node(
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9]; [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10]; [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:11]; [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:12]; [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:13]; [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:13];
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14]; [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:14];
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15]; [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:15];
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16]; [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:16];
[encoder setBytes:&r2 length:sizeof(r2) atIndex:17]; [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
[encoder setBytes:&r3 length:sizeof(r3) atIndex:18]; [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:18];
[encoder setBytes:&r2 length:sizeof(r2) atIndex:19];
[encoder setBytes:&r3 length:sizeof(r3) atIndex:20];
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 || if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||
src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) { src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
} }
else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) { else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
@ -2048,6 +2054,9 @@ static void ggml_metal_encode_node(
GGML_ASSERT(src1t == GGML_TYPE_F32); GGML_ASSERT(src1t == GGML_TYPE_F32);
GGML_ASSERT(ne03 == 1);
GGML_ASSERT(ne13 == 1);
// find the break-even point where the matrix-matrix kernel becomes more efficient compared // find the break-even point where the matrix-matrix kernel becomes more efficient compared
// to the matrix-vector kernel // to the matrix-vector kernel
// ne20 = n_used_experts // ne20 = n_used_experts

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -213,6 +213,7 @@ struct vk_device_struct {
vk_pipeline pipeline_sum_rows_f32; vk_pipeline pipeline_sum_rows_f32;
vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16; vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
vk_pipeline pipeline_timestep_embedding_f32; vk_pipeline pipeline_timestep_embedding_f32;
vk_pipeline pipeline_pool2d_f32;
std::unordered_map<std::string, vk_pipeline_ref> pipelines; std::unordered_map<std::string, vk_pipeline_ref> pipelines;
std::unordered_map<std::string, uint64_t> pipeline_descriptor_set_requirements; std::unordered_map<std::string, uint64_t> pipeline_descriptor_set_requirements;
@ -403,6 +404,17 @@ struct vk_op_timestep_embedding_push_constants {
uint32_t max_period; uint32_t max_period;
}; };
// Push-constant block for the pool2d compute pipeline: thirteen 32-bit fields.
// Member order is significant: ggml_vk_pool_2d brace-initializes this struct positionally and
// the pool2d shader declares its `parameter` push-constant block with the same field sequence.
struct vk_op_pool2d_push_constants {
    // input extent
    uint32_t IW;
    uint32_t IH;
    // output extent
    uint32_t OW;
    uint32_t OH;
    uint32_t OC;
    // total number of output elements (one shader invocation each)
    uint32_t pelements;
    // pooling operation selector
    uint32_t op;
    // kernel size
    int32_t k0;
    int32_t k1;
    // stride
    int32_t s0;
    int32_t s1;
    // padding
    int32_t p0;
    int32_t p1;
};
// Allow pre-recording command buffers // Allow pre-recording command buffers
struct vk_staging_memcpy { struct vk_staging_memcpy {
vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {} vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
@ -1803,6 +1815,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_pool2d_f32, "pool2d_f32", pool2d_f32_len, pool2d_f32_data, "main", 2, sizeof(vk_op_pool2d_push_constants), {512, 1, 1}, {}, 1);
for (auto &c : compiles) { for (auto &c : compiles) {
c.wait(); c.wait();
} }
@ -4234,6 +4248,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return ctx->device->pipeline_timestep_embedding_f32; return ctx->device->pipeline_timestep_embedding_f32;
} }
return nullptr; return nullptr;
case GGML_OP_POOL_2D:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_pool2d_f32;
}
return nullptr;
case GGML_OP_LEAKY_RELU: case GGML_OP_LEAKY_RELU:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_leaky_relu_f32; return ctx->device->pipeline_leaky_relu_f32;
@ -4464,6 +4483,14 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
uint32_t half_ceil = (dim + 1) / 2; uint32_t half_ceil = (dim + 1) / 2;
elements = { half_ceil, (uint32_t)src0->ne[0], 1 }; elements = { half_ceil, (uint32_t)src0->ne[0], 1 };
} break; } break;
case GGML_OP_POOL_2D:
{
const uint32_t N = dst->ne[3];
const uint32_t OC = dst->ne[2];
const uint32_t OH = dst->ne[1];
const uint32_t OW = dst->ne[0];
elements = { N * OC * OH * OW, 1, 1};
} break;
case GGML_OP_ADD: case GGML_OP_ADD:
case GGML_OP_DIV: case GGML_OP_DIV:
case GGML_OP_MUL: case GGML_OP_MUL:
@ -4914,6 +4941,34 @@ static void ggml_vk_timestep_embedding(ggml_backend_vk_context * ctx, vk_context
}, dryrun); }, dryrun);
} }
// Dispatch (or dry-run) the pool2d pipeline computing `dst` from `src0`.
//
// The pooling parameters are read from dst->op_params. Note the deliberate pair swap:
// the push-constant fields k0/s0/p0 receive op_params[2]/[4]/[6] and k1/s1/p1 receive
// op_params[1]/[3]/[5]; the pool2d shader applies its *0 fields along the height axis.
static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
    const int32_t * params = dst->op_params;

    const uint32_t pool_op = static_cast<uint32_t>(params[0]);

    // values destined for the push-constant slots of the same name
    const int32_t pc_k0 = params[2];
    const int32_t pc_k1 = params[1];
    const int32_t pc_s0 = params[4];
    const int32_t pc_s1 = params[3];
    const int32_t pc_p0 = params[6];
    const int32_t pc_p1 = params[5];

    // input extent
    const uint32_t in_w = src0->ne[0];
    const uint32_t in_h = src0->ne[1];

    // output extent
    const uint32_t out_w    = dst->ne[0];
    const uint32_t out_h    = dst->ne[1];
    const uint32_t channels = dst->ne[2];
    const uint32_t batch    = dst->ne[3];

    // one invocation per output element
    const uint32_t n_out = batch * channels * out_h * out_w;

    ggml_vk_op_f32<vk_op_pool2d_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_POOL_2D, {
        in_w, in_h, out_w, out_h, channels,
        n_out,
        pool_op,
        pc_k0, pc_k1, pc_s0, pc_s1, pc_p0, pc_p1,
    }, dryrun);
}
static void ggml_vk_leaky_relu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { static void ggml_vk_leaky_relu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
const float * op_params = (const float *)dst->op_params; const float * op_params = (const float *)dst->op_params;
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f }, dryrun); ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f }, dryrun);
@ -5792,6 +5847,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
case GGML_OP_SUM_ROWS: case GGML_OP_SUM_ROWS:
case GGML_OP_IM2COL: case GGML_OP_IM2COL:
case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_POOL_2D:
case GGML_OP_LEAKY_RELU: case GGML_OP_LEAKY_RELU:
break; break;
default: default:
@ -5927,6 +5983,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_TIMESTEP_EMBEDDING:
ggml_vk_timestep_embedding(ctx, compute_ctx, src0, node, dryrun); ggml_vk_timestep_embedding(ctx, compute_ctx, src0, node, dryrun);
break;
case GGML_OP_POOL_2D:
ggml_vk_pool_2d(ctx, compute_ctx, src0, node, dryrun);
break; break;
case GGML_OP_LEAKY_RELU: case GGML_OP_LEAKY_RELU:
ggml_vk_leaky_relu(ctx, compute_ctx, src0, node, dryrun); ggml_vk_leaky_relu(ctx, compute_ctx, src0, node, dryrun);
@ -6018,6 +6078,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
case GGML_OP_SUM_ROWS: case GGML_OP_SUM_ROWS:
case GGML_OP_IM2COL: case GGML_OP_IM2COL:
case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_POOL_2D:
case GGML_OP_LEAKY_RELU: case GGML_OP_LEAKY_RELU:
case GGML_OP_REPEAT: case GGML_OP_REPEAT:
buf = tensor->buffer; buf = tensor->buffer;
@ -6821,6 +6882,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_OP_SUM_ROWS: case GGML_OP_SUM_ROWS:
case GGML_OP_IM2COL: case GGML_OP_IM2COL:
case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_POOL_2D:
case GGML_OP_LEAKY_RELU: case GGML_OP_LEAKY_RELU:
return true; return true;
default: default:
@ -7334,6 +7396,16 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
const int32_t dim = tensor->op_params[0]; const int32_t dim = tensor->op_params[0];
const int32_t max_period = tensor->op_params[1]; const int32_t max_period = tensor->op_params[1];
tensor_clone = ggml_timestep_embedding(ggml_ctx, src0_clone, dim, max_period); tensor_clone = ggml_timestep_embedding(ggml_ctx, src0_clone, dim, max_period);
} else if (tensor->op == GGML_OP_POOL_2D) {
enum ggml_op_pool op = static_cast<ggml_op_pool>(dst->op_params[0]);
const int32_t k0 = tensor->op_params[1];
const int32_t k1 = tensor->op_params[2];
const int32_t s0 = tensor->op_params[3];
const int32_t s1 = tensor->op_params[4];
const int32_t p0 = tensor->op_params[5];
const int32_t p1 = tensor->op_params[6];
tensor_clone = ggml_pool_2d(ggml_ctx, src0_clone, op, k0, k1, s0, s1, p0, p1);
} else if (tensor->op == GGML_OP_LEAKY_RELU) { } else if (tensor->op == GGML_OP_LEAKY_RELU) {
const float * op_params = (const float *)tensor->op_params; const float * op_params = (const float *)tensor->op_params;
tensor_clone = ggml_leaky_relu(ggml_ctx, src0_clone, op_params[0], false); tensor_clone = ggml_leaky_relu(ggml_ctx, src0_clone, op_params[0], false);

View file

@ -942,6 +942,36 @@ class tinyBLAS_Q0_AVX {
return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)), _mm_set1_epi8(8)); return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)), _mm_set1_epi8(8));
} }
// Load one Q5_0 block as a 256-bit vector by OR-ing two parts:
//  - denibble(b->qs): defined elsewhere in this class; presumably expands the packed 4-bit values
//  - bittobyte(b->qh): 0xF0 in every byte whose qh bit is clear, 0x00 where it is set
inline __m256i load(const block_q5_0 *b) {
return _mm256_or_si256(denibble(b->qs), bittobyte(b->qh));
}
// 128-bit variant: decode the low nibbles of the 16 qs bytes, combined with qh bits 0..15.
// Each result byte is (low nibble) when its qh bit is set and (low nibble | 0xF0) when it is
// clear, i.e. the 5-bit value minus 16 when the byte is read as a signed 8-bit integer.
inline __m128i load0(const block_q5_0* b) {
const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
// memcpy avoids an unaligned / type-punned read of the 4 qh bytes
uint32_t x32;
memcpy(&x32, b->qh, sizeof(uint32_t));
__m128i qxl = _mm_and_si128(_mm_set1_epi8(15), x);
// shuffle: bytes 0..7 receive qh byte 0, bytes 8..15 receive qh byte 1
// OR mask: byte k of each 64-bit half is ~(1 << k), so the OR equals 0xFF exactly when bit k is set
// cmpeq against all-ones then yields 0xFF where the bit is set and 0x00 where it is clear
__m128i bytesl = _mm_cmpeq_epi8(_mm_set1_epi64x(-1),
_mm_or_si128(_mm_set1_epi64x(0x7fbfdfeff7fbfdfe),
_mm_shuffle_epi8(_mm_set1_epi32(x32),
_mm_set_epi64x(0x0101010101010101, 0x0000000000000000))));
// andnot inverts the selection: 0xF0 where the bit is clear, 0x00 where it is set
bytesl = _mm_andnot_si128(bytesl, _mm_set1_epi8((char)0xF0));
return _mm_or_si128(qxl, bytesl);
}
// 128-bit variant: decode the high nibbles of the 16 qs bytes, combined with qh bits 16..31.
// Each result byte is (high nibble) when its qh bit is set and (high nibble | 0xF0) when it is
// clear, i.e. the 5-bit value minus 16 when the byte is read as a signed 8-bit integer.
inline __m128i load1(const block_q5_0* b) {
const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
// memcpy avoids an unaligned / type-punned read of the 4 qh bytes
uint32_t x32;
memcpy(&x32, b->qh, sizeof(uint32_t));
// the 16-bit shift drags bits across byte boundaries; the 0x0F mask removes them again
__m128i qxh = _mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4));
// shuffle: bytes 0..7 receive qh byte 2, bytes 8..15 receive qh byte 3
// OR mask: byte k of each 64-bit half is ~(1 << k), so the OR equals 0xFF exactly when bit k is set
// cmpeq against all-ones then yields 0xFF where the bit is set and 0x00 where it is clear
__m128i bytesh = _mm_cmpeq_epi8(_mm_set1_epi64x(-1),
_mm_or_si128(_mm_set1_epi64x(0x7fbfdfeff7fbfdfe),
_mm_shuffle_epi8(_mm_set1_epi32(x32),
_mm_set_epi64x(0x0303030303030303, 0x0202020202020202))));
// andnot inverts the selection: 0xF0 where the bit is clear, 0x00 where it is set
bytesh = _mm_andnot_si128(bytesh, _mm_set1_epi8((char)0xF0));
return _mm_or_si128(qxh, bytesh);
}
inline __m256i load(const block_iq4_nl *b) { inline __m256i load(const block_iq4_nl *b) {
return MM256_SET_M128I(load1(b), load0(b)); return MM256_SET_M128I(load1(b), load0(b));
} }
@ -973,6 +1003,17 @@ class tinyBLAS_Q0_AVX {
_mm_srli_epi16(x, 4), 1)); _mm_srli_epi16(x, 4), 1));
} }
// Expand the 32 bits stored at p into 32 bytes: byte i of the result is 0x00 when bit i is
// set and 0xF0 when bit i is clear.
static inline __m256i bittobyte(const uint8_t *p) {
// memcpy avoids an unaligned / type-punned read of the 4 source bytes
uint32_t x32;
memcpy(&x32, p, sizeof(uint32_t));
// shuffle: each group of 8 result bytes receives one source byte (0, 1, 2, 3 in ascending order)
// OR mask: byte k of every 64-bit group is ~(1 << k), so the OR equals 0xFF exactly when bit k is set
// cmpeq against all-ones then yields 0xFF where the bit is set and 0x00 where it is clear
__m256i bytes = _mm256_cmpeq_epi8(_mm256_set1_epi64x(-1),
_mm256_or_si256(_mm256_set1_epi64x(0x7fbfdfeff7fbfdfe),
_mm256_shuffle_epi8(_mm256_set1_epi32(x32),
_mm256_set_epi64x(0x0303030303030303, 0x0202020202020202,
0x0101010101010101, 0x0000000000000000))));
// andnot inverts the selection: 0xF0 where the bit is clear, 0x00 where it is set
return _mm256_andnot_si256(bytes, _mm256_set1_epi8((char)0xF0));
}
const TA *const A; const TA *const A;
const TB *const B; const TB *const B;
TC *const C; TC *const C;
@ -1182,6 +1223,22 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
#endif #endif
} }
case GGML_TYPE_Q5_0: {
if (Btype != GGML_TYPE_Q8_0)
return false;
#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
tinyBLAS_Q0_AVX<block_q5_0, block_q8_0, float> tb{
k, (const block_q5_0 *)A, lda,
(const block_q8_0 *)B, ldb,
(float *)C, ldc,
ith, nth};
tb.matmul(m, n);
return true;
#else
return false;
#endif
}
case GGML_TYPE_IQ4_NL: { case GGML_TYPE_IQ4_NL: {
if (Btype != GGML_TYPE_Q8_0) if (Btype != GGML_TYPE_Q8_0)
return false; return false;

View file

@ -0,0 +1,74 @@
#version 450
#include "types.comp"
#extension GL_EXT_shader_16bit_storage : require
layout(push_constant) uniform parameter {
uint IW; uint IH;
uint OW; uint OH;
uint OC;
uint pelements;
uint op;
int k0; int k1;
int s0; int s1;
int p0; int p1;
} p;
#define BLOCK_SIZE 512
#define FLT_MAX 3.402823466e+38F
#define OP_POOL_MAX 0u
#define OP_POOL_AVG 1u
layout (local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
layout(binding = 0) readonly buffer X {A_TYPE data_a[];};
layout(binding = 1) writeonly buffer D {D_TYPE data_d[];};
// 2D pooling: each invocation computes exactly one output element.
// The flat invocation index is decomposed into (plane, output row, output column); the plane
// index `nc` addresses both the input plane (IH x IW) and the output plane (OH x OW).
// Note that the *0 parameters (s0, p0, k0) are applied along the height axis and the *1
// parameters along the width axis.
void main() {
const uint idx = gl_GlobalInvocationID.x;
// surplus invocations in the last workgroup have nothing to do
if (idx >= p.pelements) {
return;
}
const uint O_HW = p.OW * p.OH;
const uint nc = idx / O_HW;
const uint cur_oh = (idx % O_HW) / p.OW;
const uint cur_ow = (idx % O_HW) % p.OW;
// pooling window along the height, clamped to [0, IH)
// NOTE(review): start_h + p.k0 is signed while p.IH is unsigned - confirm the intended
// conversion in min() when the sum is negative (window entirely inside the padding)
const int start_h = int(cur_oh) * p.s0 - p.p0;
const uint bh = max(start_h, 0);
const uint eh = min(start_h + p.k0, p.IH);
// pooling window along the width, clamped to [0, IW)
const int start_w = int(cur_ow) * p.s1 - p.p1;
const uint bw = max(start_w, 0);
const uint ew = min(start_w + p.k1, p.IW);
// the average divides by the full kernel area even when the clamped window is smaller,
// so positions that fall into the padding contribute zero
const float scale = 1.0 / float(p.k0 * p.k1);
float res;
// neutral start value for the selected operation; unknown operations write nothing
if (p.op == OP_POOL_AVG) {
res = 0.0;
} else if (p.op == OP_POOL_MAX) {
res = -FLT_MAX;
} else {
return;
}
#pragma unroll
for (uint i = bh; i < eh; i++) {
#pragma unroll
for (uint j = bw; j < ew; j++) {
const float cur = D_TYPE(data_a[nc * p.IH * p.IW + i * p.IW + j]);
if (p.op == OP_POOL_AVG) {
res += cur * scale;
} else if (p.op == OP_POOL_MAX) {
res = max(res, cur);
}
}
}
data_d[nc * O_HW + cur_oh * p.OW + cur_ow] = res;
}

View file

@ -494,6 +494,10 @@ void process_shaders(std::vector<std::future<void>>& tasks) {
tasks.push_back(std::async(std::launch::async, [=] { tasks.push_back(std::async(std::launch::async, [=] {
string_to_spv("timestep_embedding_f32", "timestep_embedding.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("timestep_embedding_f32", "timestep_embedding.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
})); }));
tasks.push_back(std::async(std::launch::async, [=] {
string_to_spv("pool2d_f32", "pool2d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
}));
} }
void write_output_files() { void write_output_files() {

View file

@ -1089,9 +1089,6 @@ extern "C" {
/// @details Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841 /// @details Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841
LLAMA_API struct llama_sampler * llama_sampler_init_min_p (float p, size_t min_keep); LLAMA_API struct llama_sampler * llama_sampler_init_min_p (float p, size_t min_keep);
/// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
LLAMA_API struct llama_sampler * llama_sampler_init_tail_free (float z, size_t min_keep);
/// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666. /// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
LLAMA_API struct llama_sampler * llama_sampler_init_typical (float p, size_t min_keep); LLAMA_API struct llama_sampler * llama_sampler_init_typical (float p, size_t min_keep);
@ -1143,6 +1140,16 @@ extern "C" {
bool penalize_nl, // consider newlines as a repeatable token bool penalize_nl, // consider newlines as a repeatable token
bool ignore_eos); // ignore the end-of-sequence token bool ignore_eos); // ignore the end-of-sequence token
/// @details DRY sampler, designed by p-e-w, as described in: https://github.com/oobabooga/text-generation-webui/pull/5677, porting Koboldcpp implementation authored by pi6am: https://github.com/LostRuins/koboldcpp/pull/982
LLAMA_API struct llama_sampler * llama_sampler_init_dry(
const struct llama_model * model,
float dry_multiplier,
float dry_base,
int32_t dry_allowed_length,
int32_t dry_penalty_last_n,
const char ** seq_breakers,
size_t num_breakers);
LLAMA_API struct llama_sampler * llama_sampler_init_logit_bias( LLAMA_API struct llama_sampler * llama_sampler_init_logit_bias(
int32_t n_vocab, int32_t n_vocab,
int32_t n_logit_bias, int32_t n_logit_bias,

View file

@ -113,7 +113,7 @@ static void llama_sampler_softmax_impl(llama_token_data_array * cur_p) {
} }
static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k) { static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k) {
// TODO: move bucket sort to separate function so that top_p/tail_free/typical/softmax first is equally fast // TODO: move bucket sort to separate function so that top_p/typical/softmax first is equally fast
// if (k >= (int32_t)cur_p->size) { // if (k >= (int32_t)cur_p->size) {
// return; // return;
// } // }
@ -733,101 +733,6 @@ struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) {
}; };
} }
// tail-free
// State of the tail-free sampler; both values are fixed at construction (const members).
struct llama_sampler_tail_free {
// threshold on the cumulative normalized second differences; z >= 1.0 makes apply() a no-op
const float z;
// apply() never cuts at an index below this value
const size_t min_keep;
};
// Name callback of the tail-free sampler; the sampler argument is unused.
static const char * llama_sampler_tail_free_name(const struct llama_sampler * /*smpl*/) {
return "tail-free";
}
// Tail-free filtering: after a softmax over the candidates, take the absolute second
// differences of the probabilities, normalize them to sum to one, and truncate the candidate
// list at the first index where their running sum exceeds z (but never below min_keep).
static void llama_sampler_tail_free_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
    const auto * params = (llama_sampler_tail_free *) smpl->ctx;

    // z >= 1 disables the filter; at least three candidates are needed for one second difference
    const bool disabled  = params->z >= 1.0f;
    const bool too_small = cur_p->size <= 2;
    if (disabled || too_small) {
        return;
    }

    llama_sampler_softmax_impl(cur_p);

    // first differences between neighbouring probabilities
    std::vector<float> slope(cur_p->size - 1);
    for (size_t k = 0; k < slope.size(); ++k) {
        slope[k] = cur_p->data[k].p - cur_p->data[k + 1].p;
    }

    // magnitude of the second differences
    std::vector<float> curvature(cur_p->size - 2);
    for (size_t k = 0; k < curvature.size(); ++k) {
        curvature[k] = std::abs(slope[k] - slope[k + 1]);
    }

    // normalize to a distribution; fall back to a uniform one when the total is (almost) zero
    const float total = std::accumulate(curvature.begin(), curvature.end(), 0.0f);
    if (total > 1e-6f) {
        for (size_t k = 0; k < curvature.size(); ++k) {
            curvature[k] /= total;
        }
    } else {
        const float uniform = 1.0f / curvature.size();
        for (size_t k = 0; k < curvature.size(); ++k) {
            curvature[k] = uniform;
        }
    }

    // locate the cut: first index whose running sum exceeds z, provided it is not below min_keep
    size_t n_keep  = cur_p->size;
    float  running = 0.0f;
    for (size_t k = 0; k < curvature.size(); ++k) {
        running += curvature[k];

        if (running > params->z && k >= params->min_keep) {
            n_keep = k;
            break;
        }
    }

    // drop everything from the cut onwards
    cur_p->size = n_keep;
}
// Clone callback: creates a fresh tail-free sampler with the same z and min_keep.
static struct llama_sampler * llama_sampler_tail_free_clone(const struct llama_sampler * smpl) {
    const auto & src = *(const llama_sampler_tail_free *) smpl->ctx;

    return llama_sampler_init_tail_free(src.z, src.min_keep);
}
// Free callback: releases the llama_sampler_tail_free state stored in smpl->ctx.
// Only the ctx object is deleted here, not the llama_sampler itself.
static void llama_sampler_tail_free_free(struct llama_sampler * smpl) {
delete (llama_sampler_tail_free *) smpl->ctx;
}
// Callback table of the tail-free sampler.
// accept and reset are left null: this sampler provides no callbacks for them.
static struct llama_sampler_i llama_sampler_tail_free_i = {
/* .name = */ llama_sampler_tail_free_name,
/* .accept = */ nullptr,
/* .apply = */ llama_sampler_tail_free_apply,
/* .reset = */ nullptr,
/* .clone = */ llama_sampler_tail_free_clone,
/* .free = */ llama_sampler_tail_free_free,
};
// Creates a tail-free sampler.
//   z        - threshold compared against the running sum of normalized second derivatives
//   min_keep - minimum number of candidates that must be passed before truncation may happen
// The returned sampler and its context are heap-allocated; the context is released
// through the `free` entry of llama_sampler_tail_free_i.
struct llama_sampler * llama_sampler_init_tail_free(float z, size_t min_keep) {
    auto * state = new llama_sampler_tail_free {
        /* .z        = */ z,
        /* .min_keep = */ min_keep,
    };

    return new llama_sampler {
        /* .iface = */ &llama_sampler_tail_free_i,
        /* .ctx   = */ state,
    };
}
// typical // typical
struct llama_sampler_typical { struct llama_sampler_typical {
@ -1683,6 +1588,397 @@ struct llama_sampler * llama_sampler_init_penalties(
}; };
} }
// DRY
// State of the DRY repetition-penalty sampler. Instances are created with positional
// aggregate initialization, so the member order below must not change.
struct llama_sampler_dry {
// Context size given at construction. Used as the scan window when
// dry_penalty_last_n == -1 and as an upper bound on the scanned window otherwise.
int32_t total_context_size;
// Scale factor of the penalty; a value of 0 disables the sampler.
const float dry_multiplier;
// Base of the exponential penalty; a value below 1 disables the sampler.
const float dry_base;
// Repeats shorter than this many tokens are not penalized.
const int32_t dry_allowed_length;
// How many recent tokens to scan: 0 disables the sampler, -1 means total_context_size.
const int32_t dry_penalty_last_n;
// Sequence breakers keyed by their first ("head") token; the mapped value is the
// remaining ("tail") tokens, which is empty for single-token breakers.
std::unordered_multimap<llama_token, std::vector<llama_token>> dry_processed_breakers;
// Scratch buffer rebuilt on every apply(): repeat length found at each scanned position.
std::vector<int> dry_repeat_count;
// Scratch map rebuilt on every apply(): token -> longest repeat that token would extend.
std::unordered_map<llama_token, int> dry_max_token_repeat;
// Tokens recorded by accept(), in a ring buffer.
ring_buffer<llama_token> last_tokens;
};
// Ported from Koboldcpp, original PR: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am)
static void get_overlapping_token_sequences(const llama_vocab & vocab, const std::string& str, std::unordered_multimap<llama_token, std::vector<llama_token>>& token_sequences, int max_tail_len = -1) {
for (llama_token token_id = 0; token_id < (llama_token)vocab.n_vocab; token_id++) {
std::string word = llama_detokenize(vocab, {token_id}, true);
if (word.find(str) != std::string::npos) {
token_sequences.emplace(token_id, std::vector<llama_token>());
} else {
size_t word_len = word.size(), str_len = str.size();
size_t pos = -1;
while ((pos = word.find(str[0], pos + 1)) != std::string::npos) {
bool match = true;
size_t i;
for (i = 1; i < str_len && i + pos < word_len; ++i) {
if (word[pos + i] != str[i]) {
match = false;
break;
}
}
if (match) {
std::vector<llama_token> tokenization = llama_tokenize_internal(vocab, str.substr(i), false, false);
if (max_tail_len >= 0 && tokenization.size() > (size_t)max_tail_len) {
tokenization.resize(max_tail_len);
}
// Ensure we don't already have a duplicate matching tokenization
auto its = token_sequences.equal_range(token_id);
bool found = false;
for (auto it = its.first; it != its.second; ++it) {
if (tokenization == it->second) {
found = true;
break;
}
}
if (!found) {
token_sequences.emplace(token_id, tokenization);
}
}
}
}
}
}
// Name hook of the DRY sampler; the sampler argument is not used.
static const char * llama_sampler_dry_name(const struct llama_sampler * /*smpl*/) {
    static const char k_name[] = "dry";
    return k_name;
}
// Accept hook of the DRY sampler: records `token` in the history buffer.
// Nothing is recorded while the sampler is disabled (multiplier == 0, base < 1 or
// penalty_last_n == 0), which is the same condition apply() uses to bail out.
static void llama_sampler_dry_accept(struct llama_sampler * smpl, llama_token token) {
    auto * state = (llama_sampler_dry *) smpl->ctx;

    const bool disabled =
        state->dry_multiplier == 0.0f ||
        state->dry_base < 1.0f ||
        state->dry_penalty_last_n == 0;

    if (!disabled) {
        state->last_tokens.push_back(token);
    }
}
// Ported from Koboldcpp, original PR: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am)
//
// Apply hook of the DRY sampler. Lowers the logit of every candidate token that would
// extend a token sequence already present in the recent history (`last_tokens`):
//   penalty = dry_multiplier * dry_base ^ (repeat_len - dry_allowed_length)
// Candidates that are single-token sequence breakers are never penalized.
// The function returns without touching `cur_p` when the sampler is disabled or when the
// history that may be scanned is not longer than `dry_allowed_length`. When it runs to the
// end it marks `cur_p` as unsorted because logits have been modified.
// NOTE(review): `rat(i)` is assumed to return the i-th most recent token (ring_buffer is
// defined outside this block) - confirm.
static void llama_sampler_dry_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
auto * ctx = (llama_sampler_dry *) smpl->ctx;
// Same "disabled" condition as in accept().
if (ctx->dry_multiplier == 0.0f || ctx->dry_base < 1.0f || ctx->dry_penalty_last_n == 0) {
return;
}
// -1 selects the whole context; any other negative value collapses to a window of 0.
int32_t effective_dry_penalty_last_n = (ctx->dry_penalty_last_n == -1) ? ctx->total_context_size : std::max(ctx->dry_penalty_last_n, 0);
// Number of history tokens to scan: bounded by what has been recorded, by the configured
// window and by the context size.
int last_n_repeat = std::min(std::min((int)ctx->last_tokens.size(), effective_dry_penalty_last_n), ctx->total_context_size);
if (last_n_repeat <= ctx->dry_allowed_length) {
return;
}
// Reset the per-call scratch state.
ctx->dry_repeat_count.assign(last_n_repeat, 0);
ctx->dry_max_token_repeat.clear();
// Step 1: Look for restart sequences to limit the maximum repetition length.
// Work backwards through the context looking for any token that begins a restart sequence.
//
// The collection `restart_sequences` (stored here as `dry_processed_breakers`) is a mapping
// from a "head" token to all "tail" sequences that together comprise a restart sequence.
// This allows us to quickly check whether each token is the head of a complete sequence.
// Most restart sequences are actually a single token, and for these the "tail" is an empty vector.
//
// If the token is a "head", test all restart sequences that begin with this token
// (there will often only be one sequence for each token, but if sequences like 'aaaq1' and
// 'aaa1' are used as restart strings, both could start with 'aaa' when tokenized). The
// longest matching sequence (if any) is used to limit the maximum repetition length.
//
// Note that in the case of a short sequence contained in a longer one, this might fail to
// find the smallest value for `rep_limit`. For example, if 'amniotic' and 'ni' are both used as
// restart sequences, 'ni' will be found first, and since it's shorter it will fail to suppress
// 'otic'. This is a minor issue since fully contained restart sequences are likely to be rare.
//
// This is theoretically worst-case O(N^2) for arbitrary restart sequences, which is why we
// have already clamped the maximum tail sequence length when generating `restart_sequences`.
// With clamping, this scan is O(N) in the context length.
int rep_limit = last_n_repeat;
for (int i = 0; i < last_n_repeat; ++i) {
llama_token token = ctx->last_tokens.rat(i);
auto its = ctx->dry_processed_breakers.equal_range(token);
// No breaker starts with this token.
if (its.first == ctx->dry_processed_breakers.end()) {
continue;
}
int longest_match = -1;
for (auto it = its.first; it != its.second; ++it) {
// Note that (*it) does not contain the head character, so seq_len will be
// the restart sequence length minus 1.
// In the common case of a single-token restart sequence, (*it) will be empty
// and we will trivially match.
int seq_len = (int)it->second.size();
// Only consider tails that are longer than the best match so far and that fit
// into the tokens that follow the head in the history.
if (seq_len > longest_match && seq_len <= (int)i) {
bool match = true;
for (int offset = 0; offset < seq_len; ++offset) {
// The -1 when indexing `last_tokens` is because we already matched the head.
if (it->second[offset] != ctx->last_tokens.rat(i - offset - 1)) {
match = false;
break;
}
}
if (match) {
longest_match = seq_len;
}
}
}
if (longest_match >= 0) {
// We found a restart sequence starting `i` tokens from the end and continuing for
// `longest_match` tokens.
rep_limit = i - longest_match;
break;
}
}
// The span after the most recent breaker is too short to reach the allowed length.
if (rep_limit < ctx->dry_allowed_length) {
return;
}
// Step 2: Iterate in reverse over the last N tokens of the context, using the "Z-algorithm" (in
// the reverse direction) to efficiently compute the positions and lengths of suffixes appearing
// elsewhere in the context. We limit the suffix length to `rep_limit` to respect restart sequences.
//
// This algorithm is not currently documented on Wikipedia, but there is a clear description here:
// https://ivanyu.me/blog/2014/10/15/z-algorithm/
//
// The code below is adapted from the public domain implementation by the same author here:
// https://github.com/ivanyu/string-algorithms/blob/master/z_algorithm.py
//
// Example:
// Last N tokens: a b c c b c y a b c
// Repeat counts: 0 0 3 1 0 2 0 0 0 0
// ^
// This `3` means that the last three tokens of the context (a b c) also appear here.
//
// This step is worst case O(N) since the Z-algorithm is linear, despite the appearance of nested
// for/while loops. This can be seen by observing that the `lt` and `rt` bounds are set after each
// repeated suffix is detected (i.e. after each while loop when n > 0). These bound variables
// ensure that the inner while loops only examine each token in the context once as the outer
// for loop iterates over the context.
{
// `dry_repeat_count` is indexed in forward order while `rat()` counts from the end,
// hence the `last - k` index translation below.
const int last = last_n_repeat - 1;
int rt = 0, lt = 0;
for (int k = 1; k < last_n_repeat; ++k) {
if (k > rt) {
// If k is outside the current Z-box, do naive computation.
int n = 0;
while (n + k < last_n_repeat && ctx->last_tokens.rat(n) == ctx->last_tokens.rat(n+k)) {
++n;
}
ctx->dry_repeat_count[last - k] = std::min(n, rep_limit);
if (n > 0) {
lt = k;
rt = k+n-1;
}
} else {
// If k is inside the current Z-box, consider two cases.
int p = k - lt; // Pair index.
int right_part_len = rt - k + 1;
if (ctx->dry_repeat_count[last - p] < right_part_len) {
// The previously computed value fits inside the Z-box and can be reused.
int n = std::min(ctx->dry_repeat_count[last - p], rep_limit);
ctx->dry_repeat_count[last - k] = n;
} else {
// The match may extend past the Z-box: continue comparing from its right edge.
int i = rt + 1;
while (i < last_n_repeat && ctx->last_tokens.rat(i) == ctx->last_tokens.rat(i - k)) {
i += 1;
}
int n = std::min(i - k, rep_limit);
ctx->dry_repeat_count[last - k] = n;
lt = k;
rt = i - 1;
}
}
}
}
// Step 3: Iterate over dry_repeat_count and last_tokens, examining the maximum repeat length
// that would be generated by emitting each new token that would extend a sequence.
//
// Following the same example as above:
// Last N tokens: a b c c b c y a b c
// Repeat counts: 0 0 3 1 0 2 0 0 0 0
//
// For each non-zero, look ahead one token. This token, if emitted, would extend the repetition.
// c: 3 -> 4 (from `a b c` to `a b c c`)
// b: 1 -> 2 (from `c` to `c b`)
// y: 2 -> 3 (from `b c` to `b c y`)
for (int i = 0; i < last_n_repeat - 1; ++i) {
int repeat_len = ctx->dry_repeat_count[i];
if (repeat_len >= ctx->dry_allowed_length) {
// This token ends a repeat, so the next token would continue one.
// By convention, the value of `repeat_len` only includes the tokens currently
// in the context, not the new token that would be added.
llama_token token = ctx->last_tokens.rat(last_n_repeat - 2 - i);
// Track the maximum sequence ending in this token.
const auto& it = ctx->dry_max_token_repeat.find(token);
if (it == ctx->dry_max_token_repeat.end() || it->second < repeat_len) {
ctx->dry_max_token_repeat[token] = repeat_len;
}
}
}
// Step 4: Apply logit penalties based on the maximum repeat length for relevant tokens.
// Prevent floating point overflow in `pow(penalty_base, exponent)` by clamping to `max_exponent`.
// Compute it from `penalty_base` and the approximate log of `std::numeric_limits<float>::max()`
const float FLOAT_MAX_LOG = 88.7228391f;
// max_exponent stays 0 (meaning "no clamping") when the base is so close to 1 that the
// division below would not be meaningful.
int max_exponent = 0;
if (ctx->dry_base > 1.000001f) {
max_exponent = FLOAT_MAX_LOG / std::log(ctx->dry_base);
}
for (size_t i = 0; i < cur_p->size; ++i) {
const auto& af_kvp = ctx->dry_max_token_repeat.find(cur_p->data[i].id);
if (af_kvp != ctx->dry_max_token_repeat.end()) {
// Check all sequence breakers starting with this token
auto range = ctx->dry_processed_breakers.equal_range(cur_p->data[i].id);
bool is_single_token_breaker = false;
for (auto it = range.first; it != range.second; ++it) {
if (it->second.empty()) {
is_single_token_breaker = true;
break;
}
}
// Apply penalty only if it's not a single-token sequence breaker
if (!is_single_token_breaker) {
int repeat_exp = af_kvp->second - ctx->dry_allowed_length;
if (max_exponent > 0 && repeat_exp > max_exponent) {
repeat_exp = max_exponent;
}
float penalty = ctx->dry_multiplier * std::pow(ctx->dry_base, repeat_exp);
cur_p->data[i].logit -= penalty;
}
}
}
// Logits were changed, so any previous ordering of the candidates is no longer valid.
cur_p->sorted = false;
}
static void llama_sampler_dry_reset(struct llama_sampler * smpl) {
auto * ctx = (llama_sampler_dry *) smpl->ctx;
ctx->last_tokens.clear();
ctx->dry_repeat_count.clear();
ctx->dry_max_token_repeat.clear();
}
static struct llama_sampler * llama_sampler_dry_clone(const struct llama_sampler * smpl) {
const auto * ctx = (llama_sampler_dry *) smpl->ctx;
// nullptr is passed as vocab because it is only needed for raw sequence breaker processing, which we have already done and will be copying
auto * result = llama_sampler_init_dry(nullptr, ctx->dry_multiplier, ctx->dry_base, ctx->dry_allowed_length, ctx->dry_penalty_last_n, NULL, 0);
// Copy the state, including the processed breakers
{
auto * result_ctx = (llama_sampler_dry *) result->ctx;
result_ctx->dry_processed_breakers = ctx->dry_processed_breakers;
result_ctx->dry_repeat_count = ctx->dry_repeat_count;
result_ctx->dry_max_token_repeat = ctx->dry_max_token_repeat;
result_ctx->last_tokens = ctx->last_tokens;
}
return result;
}
static void llama_sampler_dry_free(struct llama_sampler * smpl) {
delete (llama_sampler_dry *) smpl->ctx;
}
// Interface table for the DRY sampler. The entries are positional, so their order must
// match the member order of `llama_sampler_i` (name, accept, apply, reset, clone, free).
// Unlike the tail-free table, every hook is provided: the sampler records accepted tokens
// and therefore needs accept/reset as well.
static struct llama_sampler_i llama_sampler_dry_i = {
/* .name = */ llama_sampler_dry_name,
/* .accept = */ llama_sampler_dry_accept,
/* .apply = */ llama_sampler_dry_apply,
/* .reset = */ llama_sampler_dry_reset,
/* .clone = */ llama_sampler_dry_clone,
/* .free = */ llama_sampler_dry_free,
};
// Creates a DRY sampler.
//   vocab              - vocabulary used to convert the breaker strings into token sequences
//   context_size       - stored as total_context_size; also the history window when
//                        dry_penalty_last_n == -1
//   dry_multiplier     - penalty scale; 0 disables the sampler
//   dry_base           - penalty base; values below 1 disable the sampler
//   dry_allowed_length - repeats shorter than this are not penalized
//   dry_penalty_last_n - number of recent tokens to scan; 0 disables, -1 means context_size
//   seq_breakers       - array of `num_breakers` C strings (may be null); null or empty
//                        entries are skipped with a warning
// Breaker strings longer than MAX_CHAR_LEN bytes are truncated, and the tail token sequence
// of each breaker is limited to MAX_SEQ_LEN tokens. When the sampler is disabled no breakers
// are processed and the history buffers are created empty.
// The returned sampler and its context are heap-allocated.
struct llama_sampler * llama_sampler_init_dry_impl(const struct llama_vocab & vocab, int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) {
    int32_t effective_dry_penalty_last_n = (dry_penalty_last_n == -1) ? context_size : std::max(dry_penalty_last_n, 0);
    std::unordered_multimap<llama_token, std::vector<llama_token>> processed_breakers;
    const int MAX_CHAR_LEN = 40;
    const int MAX_SEQ_LEN = 20;

    const bool dry_enabled = (dry_multiplier != 0.0f && dry_base >= 1.0f && dry_penalty_last_n != 0);

    if (dry_enabled && seq_breakers != nullptr && num_breakers > 0) {
        // Process sequence breakers
        for (size_t i = 0; i < num_breakers; ++i) {
            // This single check covers both null and empty entries, so the string built
            // below is always non-empty and needs no second emptiness test.
            if (seq_breakers[i] == nullptr || std::strlen(seq_breakers[i]) == 0) {
                LLAMA_LOG_WARN("skipping null or empty DRY sequence breaker at index %zu\n", i);
                continue;
            }

            std::string sequence_break(seq_breakers[i]);

            if (sequence_break.size() > MAX_CHAR_LEN) {
                LLAMA_LOG_WARN("truncating DRY sequence breaker to %d characters\n", MAX_CHAR_LEN);
                sequence_break.resize(MAX_CHAR_LEN);
            }

            get_overlapping_token_sequences(vocab, sequence_break, processed_breakers, MAX_SEQ_LEN);
        }
    }

    return new llama_sampler {
        /* .iface = */ &llama_sampler_dry_i,
        /* .ctx   = */ new llama_sampler_dry {
            /* .total_context_size     = */ context_size,
            /* .dry_multiplier         = */ dry_multiplier,
            /* .dry_base               = */ dry_base,
            /* .dry_allowed_length     = */ dry_allowed_length,
            /* .dry_penalty_last_n     = */ dry_penalty_last_n,
            /* .dry_processed_breakers = */ std::move(processed_breakers),
            /* .dry_repeat_count       = */ dry_enabled ? std::vector<int>(effective_dry_penalty_last_n, 0) : std::vector<int>{},
            /* .dry_max_token_repeat   = */ {},
            /* .last_tokens            = */ dry_enabled ? ring_buffer<llama_token>(effective_dry_penalty_last_n) : ring_buffer<llama_token>(0),
        },
    };
}
// wrapper for test-sampling.cpp
//
// Creates a DRY sampler whose sequence breakers are given directly as token sequences
// instead of strings, so no real vocabulary is needed. For every non-empty breaker the
// first token becomes the head and the remaining tokens become the tail. Empty breakers
// are skipped with a warning; a warning is also logged when the list is empty or when
// no breaker ends up being stored.
struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const std::vector<std::vector<llama_token>>& seq_breakers) {
    llama_vocab dummy_vocab;
    struct llama_sampler * sampler = llama_sampler_init_dry_impl(dummy_vocab, context_size, dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, nullptr, 0);

    // Process the token-based sequence breakers
    auto & breakers = ((llama_sampler_dry *) sampler->ctx)->dry_processed_breakers;
    breakers.clear();

    if (seq_breakers.empty()) {
        LLAMA_LOG_WARN("empty DRY sequence breakers list in llama_sampler_init_dry_testing\n");
        return sampler;
    }

    for (const auto & seq : seq_breakers) {
        if (seq.empty()) {
            LLAMA_LOG_WARN("skipping DRY empty sequence breaker\n");
            continue;
        }

        const llama_token head = seq.front();
        breakers.emplace(head, std::vector<llama_token>(seq.begin() + 1, seq.end()));
    }

    if (breakers.empty()) {
        LLAMA_LOG_WARN("no valid DRY sequence breakers processed in llama_sampler_init_dry_testing\n");
    }

    return sampler;
}
// logit-bias // logit-bias
struct llama_sampler_logit_bias { struct llama_sampler_logit_bias {

View file

@ -28,3 +28,21 @@ struct llama_sampler * llama_sampler_init_grammar_impl(
struct llama_sampler * llama_sampler_init_infill_impl( struct llama_sampler * llama_sampler_init_infill_impl(
const struct llama_vocab & vocab); const struct llama_vocab & vocab);
// Creates a DRY sampler from an explicit vocabulary and context size.
// `seq_breakers` is an array of `num_breakers` C strings that are converted to token
// sequences with `vocab`; it may be null when `num_breakers` is 0.
// The sampler is disabled when dry_multiplier == 0, dry_base < 1 or dry_penalty_last_n == 0;
// dry_penalty_last_n == -1 selects `context_size` as the history window.
struct llama_sampler * llama_sampler_init_dry_impl(
const struct llama_vocab & vocab,
int32_t context_size,
float dry_multiplier,
float dry_base,
int32_t dry_allowed_length,
int32_t dry_penalty_last_n,
const char ** seq_breakers,
size_t num_breakers);
// Test helper: creates a DRY sampler whose sequence breakers are supplied as token
// sequences (first token = head, remaining tokens = tail), so no vocabulary is required.
struct llama_sampler * llama_sampler_init_dry_testing(
int32_t context_size,
float dry_multiplier,
float dry_base,
int32_t dry_allowed_length,
int32_t dry_penalty_last_n,
const std::vector<std::vector<llama_token>>& seq_breakers);

View file

@ -2226,3 +2226,19 @@ int32_t llama_detokenize_impl(
return total <= text_len_max ? total : -total; return total <= text_len_max ? total : -total;
} }
// Converts `tokens` to text and returns it as a std::string.
// `special` is forwarded to llama_detokenize_impl as `unparse_special`; `remove_special`
// is always false. The first attempt uses a buffer of at least tokens.size() bytes. A
// negative result from llama_detokenize_impl is the negated number of bytes required, in
// which case the buffer is enlarged to that size and the conversion is run a second time.
std::string llama_detokenize(const struct llama_vocab & vocab, const std::vector<llama_token> & tokens, bool special) {
    std::string buf;
    buf.resize(std::max(buf.capacity(), tokens.size()));

    const int32_t n_tokens = (int32_t)tokens.size();

    int32_t written = llama_detokenize_impl(vocab, tokens.data(), n_tokens, &buf[0], (int32_t)buf.size(), false, special);
    if (written < 0) {
        // Buffer was too small: retry with exactly the reported size.
        buf.resize(-written);
        written = llama_detokenize_impl(vocab, tokens.data(), n_tokens, &buf[0], (int32_t)buf.size(), false, special);
        GGML_ASSERT(written <= (int32_t)buf.size());  // whitespace trimming is performed after per-token detokenization
    }

    // Drop the unused part of the buffer.
    buf.resize(written);

    // NOTE: the original tokenizer decodes bytes after collecting the pieces.
    return buf;
}

View file

@ -163,3 +163,8 @@ int32_t llama_detokenize_impl(
int32_t text_len_max, int32_t text_len_max,
bool remove_special, bool remove_special,
bool unparse_special); bool unparse_special);
// Convenience overload that detokenizes `tokens` into a std::string, growing the output
// buffer as needed. `special` is forwarded to llama_detokenize_impl as `unparse_special`.
std::string llama_detokenize(
const struct llama_vocab & vocab,
const std::vector<llama_token> & tokens,
bool special);

View file

@ -9670,20 +9670,16 @@ static struct ggml_tensor * llm_build_kqv(
cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias, cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f); hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_GEMMA2) { ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
}
cur = ggml_reshape_2d(ctx, cur, n_embd_head_v*n_head, n_tokens); cur = ggml_reshape_2d(ctx, cur, n_embd_head_v*n_head, n_tokens);
} else { } else {
struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
cb(kq, "kq", il); cb(kq, "kq", il);
if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2 || model.arch == LLM_ARCH_NEMOTRON || model.arch == LLM_ARCH_CHATGLM) { // note: this op tends to require high floating point range
// for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs // while for some models F16 is enough, for others it is not, so we default to F32 here
// ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847 ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
}
if (model.arch == LLM_ARCH_GROK) { if (model.arch == LLM_ARCH_GROK) {
// need to do the following: // need to do the following:
@ -9692,9 +9688,6 @@ static struct ggml_tensor * llm_build_kqv(
// kq = 30 * tanh(kq / 30) // kq = 30 * tanh(kq / 30)
// before the softmax below // before the softmax below
//try from phi2
//ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
kq = ggml_tanh(ctx, ggml_scale(ctx, kq, 0.08838834764831845f/30.0f)); kq = ggml_tanh(ctx, ggml_scale(ctx, kq, 0.08838834764831845f/30.0f));
kq = ggml_scale(ctx, kq, 30); kq = ggml_scale(ctx, kq, 30);
} }
@ -21793,6 +21786,16 @@ static int32_t llama_chat_apply_template_internal(
ss << message->content << "\n\n"; ss << message->content << "\n\n";
} }
} }
} else if (tmpl == "granite" || tmpl_contains("<|start_of_role|>")) {
// IBM Granite template
for (const auto & message : chat) {
std::string role(message->role);
ss << "<|start_of_role|>" << role << "<|end_of_role|>"
<< message->content << "<|end_of_text|>\n";
}
if (add_ass) {
ss << "<|start_of_role|>assistant<|end_of_role|>\n";
}
} else { } else {
// template not supported // template not supported
return -1; return -1;
@ -21855,6 +21858,10 @@ struct llama_sampler * llama_sampler_init_infill(const struct llama_model * mode
return llama_sampler_init_infill_impl(model->vocab); return llama_sampler_init_infill_impl(model->vocab);
} }
struct llama_sampler * llama_sampler_init_dry(const struct llama_model * model, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) {
return llama_sampler_init_dry_impl(model->vocab, llama_n_ctx_train(model), dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, seq_breakers, num_breakers);
}
// //
// model split // model split
// //