Mirror of https://github.com/LostRuins/koboldcpp.git (synced 2025-09-11 17:44:38 +00:00)

Commit a46f8acd03
Note: also has support for completion tokens count

31 changed files with 138676 additions and 137399 deletions

@@ -129,13 +129,13 @@ static void common_params_handle_model_default(common_params & params) {
             }
             params.hf_file = params.model;
         } else if (params.model.empty()) {
-            params.model = fs_get_cache_file(string_split(params.hf_file, '/').back());
+            params.model = fs_get_cache_file(string_split<std::string>(params.hf_file, '/').back());
         }
     } else if (!params.model_url.empty()) {
         if (params.model.empty()) {
-            auto f = string_split(params.model_url, '#').front();
-            f = string_split(f, '?').front();
-            params.model = fs_get_cache_file(string_split(f, '/').back());
+            auto f = string_split<std::string>(params.model_url, '#').front();
+            f = string_split<std::string>(f, '?').front();
+            params.model = fs_get_cache_file(string_split<std::string>(f, '/').back());
         }
     } else if (params.model.empty()) {
         params.model = DEFAULT_MODEL_PATH;
@@ -252,6 +252,9 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
         for (auto & antiprompt : params.antiprompt) {
            string_process_escapes(antiprompt);
         }
+        for (auto & seq_breaker : params.sparams.dry_sequence_breakers) {
+            string_process_escapes(seq_breaker);
+        }
     }

     if (!params.kv_overrides.empty()) {
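The added loop mirrors the antiprompt handling: DRY sequence breakers taken from the command line go through string_process_escapes so escaped text becomes literal characters before reaching the sampler. A self-contained sketch of the intended effect; the escape handler below is a simplified stand-in for the repository's string_process_escapes, which handles more escapes:

    #include <cassert>
    #include <string>
    #include <vector>

    // Simplified stand-in for string_process_escapes from common.h (assumption: the real
    // helper covers \t, \\, hex escapes, etc.); it only exists to keep this example standalone.
    static void process_escapes(std::string & s) {
        std::string out;
        for (size_t i = 0; i < s.size(); ++i) {
            if (s[i] == '\\' && i + 1 < s.size() && s[i + 1] == 'n') { out += '\n'; ++i; }
            else out += s[i];
        }
        s = out;
    }

    int main() {
        // breakers as they might arrive from --dry-sequence-breaker on the command line
        std::vector<std::string> breakers = { "\\n", ":" };
        for (auto & b : breakers) {
            process_escapes(b);
        }
        assert(breakers[0] == "\n"); // the escaped sequence became a literal newline
    }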
@@ -880,7 +883,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         {"--samplers"}, "SAMPLERS",
         string_format("samplers that will be used for generation in the order, separated by \';\'\n(default: %s)", sampler_type_names.c_str()),
         [](common_params & params, const std::string & value) {
-            const auto sampler_names = string_split(value, ';');
+            const auto sampler_names = string_split<std::string>(value, ';');
             params.sparams.samplers = common_sampler_types_from_names(sampler_names, true);
         }
     ).set_sparam());
@@ -941,13 +944,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.sparams.min_p = std::stof(value);
         }
     ).set_sparam());
-    add_opt(common_arg(
-        {"--tfs"}, "N",
-        string_format("tail free sampling, parameter z (default: %.1f, 1.0 = disabled)", (double)params.sparams.tfs_z),
-        [](common_params & params, const std::string & value) {
-            params.sparams.tfs_z = std::stof(value);
-        }
-    ).set_sparam());
     add_opt(common_arg(
         {"--xtc-probability"}, "N",
         string_format("xtc probability (default: %.1f, 0.0 = disabled)", (double)params.sparams.xtc_probability),
@@ -998,6 +994,64 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.sparams.penalty_freq = std::stof(value);
         }
     ).set_sparam());
+    add_opt(common_arg(
+        {"--dry-multiplier"}, "N",
+        string_format("set DRY sampling multiplier (default: %.1f, 0.0 = disabled)", (double)params.sparams.dry_multiplier),
+        [](common_params & params, const std::string & value) {
+            params.sparams.dry_multiplier = std::stof(value);
+        }
+    ).set_sparam());
+    add_opt(common_arg(
+        {"--dry-base"}, "N",
+        string_format("set DRY sampling base value (default: %.2f)", (double)params.sparams.dry_base),
+        [](common_params & params, const std::string & value) {
+            float potential_base = std::stof(value);
+            if (potential_base >= 1.0f)
+            {
+                params.sparams.dry_base = potential_base;
+            }
+        }
+    ).set_sparam());
+    add_opt(common_arg(
+        {"--dry-allowed-length"}, "N",
+        string_format("set allowed length for DRY sampling (default: %d)", params.sparams.dry_allowed_length),
+        [](common_params & params, int value) {
+            params.sparams.dry_allowed_length = value;
+        }
+    ).set_sparam());
+    add_opt(common_arg(
+        {"--dry-penalty-last-n"}, "N",
+        string_format("set DRY penalty for the last n tokens (default: %d, 0 = disable, -1 = context size)", params.sparams.dry_penalty_last_n),
+        [](common_params & params, int value) {
+            params.sparams.dry_penalty_last_n = value;
+        }
+    ).set_sparam());
+    add_opt(common_arg(
+        {"--dry-sequence-breaker"}, "STRING",
+        string_format("add sequence breaker for DRY sampling, clearing out default breakers (%s) in the process; use \"none\" to not use any sequence breakers\n",
+            params.sparams.dry_sequence_breakers.empty() ? "none" :
+            std::accumulate(std::next(params.sparams.dry_sequence_breakers.begin()),
+                params.sparams.dry_sequence_breakers.end(),
+                std::string("'") + (params.sparams.dry_sequence_breakers[0] == "\n" ? "\\n" : params.sparams.dry_sequence_breakers[0]) + "'",
+                [](const std::string& a, const std::string& b) {
+                    std::string formatted_b = (b == "\n") ? "\\n" : b;
+                    return a + ", '" + formatted_b + "'";
+                }).c_str()),
+        [](common_params & params, const std::string & value) {
+            static bool defaults_cleared = false;
+
+            if (!defaults_cleared) {
+                params.sparams.dry_sequence_breakers.clear();
+                defaults_cleared = true;
+            }
+
+            if (value == "none") {
+                params.sparams.dry_sequence_breakers.clear();
+            } else {
+                params.sparams.dry_sequence_breakers.emplace_back(value);
+            }
+        }
+    ).set_sparam());
     add_opt(common_arg(
         {"--dynatemp-range"}, "N",
         string_format("dynamic temperature range (default: %.1f, 0.0 = disabled)", (double)params.sparams.dynatemp_range),
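The --dry-sequence-breaker handler above keeps a function-local static flag so that the first user-supplied breaker replaces the built-in defaults, later occurrences append, and the literal value "none" empties the list. A self-contained sketch of that behaviour (illustration only, not code from the repository):

    #include <cassert>
    #include <string>
    #include <vector>

    int main() {
        std::vector<std::string> breakers = { "\n", ":", "\"", "*" }; // defaults from common_sampler_params
        bool defaults_cleared = false;

        // same logic as the lambda registered for --dry-sequence-breaker
        auto handle = [&](const std::string & value) {
            if (!defaults_cleared) {
                breakers.clear();
                defaults_cleared = true;
            }
            if (value == "none") {
                breakers.clear();
            } else {
                breakers.emplace_back(value);
            }
        };

        handle("###");   // e.g. --dry-sequence-breaker "###"
        handle("User:"); // e.g. --dry-sequence-breaker "User:"
        assert(breakers.size() == 2 && breakers[0] == "###"); // defaults were dropped, both values kept
    }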
@@ -1014,7 +1068,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
     ).set_sparam());
     add_opt(common_arg(
         {"--mirostat"}, "N",
-        string_format("use Mirostat sampling.\nTop K, Nucleus, Tail Free and Locally Typical samplers are ignored if used.\n"
+        string_format("use Mirostat sampling.\nTop K, Nucleus and Locally Typical samplers are ignored if used.\n"
             "(default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)", params.sparams.mirostat),
         [](common_params & params, int value) {
             params.sparams.mirostat = value;
@@ -418,19 +418,6 @@ std::string string_format(const char * fmt, ...) {
     return std::string(buf.data(), size);
 }

-std::vector<std::string> string_split(std::string input, char separator) {
-    std::vector<std::string> parts;
-    size_t separator_pos = input.find(separator);
-    while (separator_pos != std::string::npos) {
-        std::string part = input.substr(0, separator_pos);
-        parts.emplace_back(part);
-        input = input.substr(separator_pos + 1);
-        separator_pos = input.find(separator);
-    }
-    parts.emplace_back(input);
-    return parts;
-}
-
 std::string string_strip(const std::string & str) {
     size_t start = 0;
     size_t end = str.size();
@@ -2021,6 +2008,10 @@ void yaml_dump_non_result_info(FILE * stream, const common_params & params, cons
     fprintf(stream, "chunks: %d # default: -1 (unlimited)\n", params.n_chunks);
     fprintf(stream, "color: %s # default: false\n", params.use_color ? "true" : "false");
     fprintf(stream, "ctx_size: %d # default: 512\n", params.n_ctx);
+    fprintf(stream, "dry_allowed_length: %d # default: 2\n", sparams.dry_allowed_length);
+    fprintf(stream, "dry_base: %.2f # default: 1.75\n", sparams.dry_base);
+    fprintf(stream, "dry_multiplier: %.1f # default: 0.0\n", sparams.dry_multiplier);
+    fprintf(stream, "dry_penalty_last_n: %d # default: -1 (0 = disable, -1 = context size)\n", sparams.dry_penalty_last_n);
     fprintf(stream, "escape: %s # default: false\n", params.escape ? "true" : "false");
     fprintf(stream, "file: # never logged, see prompt instead. Can still be specified for input.\n");
     fprintf(stream, "frequency_penalty: %f # default: 0.0 \n", sparams.penalty_freq);
@@ -2101,7 +2092,6 @@ void yaml_dump_non_result_info(FILE * stream, const common_params & params, cons
     const std::vector<float> tensor_split_vector(params.tensor_split, params.tensor_split + llama_max_devices());
     yaml_dump_vector_float(stream, "tensor_split", tensor_split_vector);

-    fprintf(stream, "tfs: %f # default: 1.0\n", sparams.tfs_z);
     fprintf(stream, "threads: %d # default: %u\n", params.cpuparams.n_threads, std::thread::hardware_concurrency());
     fprintf(stream, "top_k: %d # default: 40\n", sparams.top_k);
     fprintf(stream, "top_p: %f # default: 0.95\n", sparams.top_p);
@@ -80,14 +80,15 @@ enum llama_example {

 enum common_sampler_type {
     COMMON_SAMPLER_TYPE_NONE        = 0,
-    COMMON_SAMPLER_TYPE_TOP_K       = 1,
-    COMMON_SAMPLER_TYPE_TOP_P       = 2,
-    COMMON_SAMPLER_TYPE_MIN_P       = 3,
-    COMMON_SAMPLER_TYPE_TFS_Z       = 4,
-    COMMON_SAMPLER_TYPE_TYPICAL_P   = 5,
-    COMMON_SAMPLER_TYPE_TEMPERATURE = 6,
-    COMMON_SAMPLER_TYPE_XTC         = 7,
-    COMMON_SAMPLER_TYPE_INFILL      = 8,
+    COMMON_SAMPLER_TYPE_DRY         = 1,
+    COMMON_SAMPLER_TYPE_TOP_K       = 2,
+    COMMON_SAMPLER_TYPE_TOP_P       = 3,
+    COMMON_SAMPLER_TYPE_MIN_P       = 4,
+  //COMMON_SAMPLER_TYPE_TFS_Z       = 5,
+    COMMON_SAMPLER_TYPE_TYPICAL_P   = 6,
+    COMMON_SAMPLER_TYPE_TEMPERATURE = 7,
+    COMMON_SAMPLER_TYPE_XTC         = 8,
+    COMMON_SAMPLER_TYPE_INFILL      = 9,
 };

 // dimensionality reduction methods, used by cvector-generator
@@ -100,34 +101,39 @@ enum dimre_method {
 struct common_sampler_params {
     uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler

     int32_t n_prev            = 64;    // number of previous tokens to remember
     int32_t n_probs           = 0;     // if greater than 0, output the probabilities of top n_probs tokens.
     int32_t min_keep          = 0;     // 0 = disabled, otherwise samplers should return at least min_keep tokens
     int32_t top_k             = 40;    // <= 0 to use vocab size
     float   top_p             = 0.95f; // 1.0 = disabled
     float   min_p             = 0.05f; // 0.0 = disabled
     float   xtc_probability   = 0.00f; // 0.0 = disabled
     float   xtc_threshold     = 0.10f; // > 0.5 disables XTC
-    float   tfs_z             = 1.00f; // 1.0 = disabled
     float   typ_p             = 1.00f; // typical_p, 1.0 = disabled
     float   temp              = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
     float   dynatemp_range    = 0.00f; // 0.0 = disabled
     float   dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler
     int32_t penalty_last_n    = 64;    // last n tokens to penalize (0 = disable penalty, -1 = context size)
     float   penalty_repeat    = 1.00f; // 1.0 = disabled
     float   penalty_freq      = 0.00f; // 0.0 = disabled
     float   penalty_present   = 0.00f; // 0.0 = disabled
+    float   dry_multiplier    = 0.0f;  // 0.0 = disabled; DRY repetition penalty for tokens extending repetition:
+    float   dry_base          = 1.75f; // 0.0 = disabled; multiplier * base ^ (length of sequence before token - allowed length)
+    int32_t dry_allowed_length = 2;    // tokens extending repetitions beyond this receive penalty
+    int32_t dry_penalty_last_n = -1;   // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size)
     int32_t mirostat          = 0;     // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
     float   mirostat_tau      = 5.00f; // target entropy
     float   mirostat_eta      = 0.10f; // learning rate
     bool    penalize_nl       = false; // consider newlines as a repeatable token
     bool    ignore_eos        = false;
     bool    no_perf           = false; // disable performance metrics

+    std::vector<std::string> dry_sequence_breakers = {"\n", ":", "\"", "*"}; // default sequence breakers for DRY
+

     std::vector<enum common_sampler_type> samplers = {
+        COMMON_SAMPLER_TYPE_DRY,
         COMMON_SAMPLER_TYPE_TOP_K,
-        COMMON_SAMPLER_TYPE_TFS_Z,
         COMMON_SAMPLER_TYPE_TYPICAL_P,
         COMMON_SAMPLER_TYPE_TOP_P,
         COMMON_SAMPLER_TYPE_MIN_P,
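The comments on the new fields describe the DRY penalty as multiplier * base ^ (length of sequence before token - allowed length). A self-contained illustration of how that grows for the values above; the 0.8 multiplier is only an example (the web UI comments elsewhere in this commit suggest it "works well"), the struct default is 0.0, i.e. disabled:

    #include <cmath>
    #include <cstdio>

    // Illustration of the formula quoted in the struct comments:
    //   penalty = dry_multiplier * pow(dry_base, repetition_length - dry_allowed_length)
    int main() {
        const float dry_multiplier     = 0.8f;  // example value; 0.0f disables DRY
        const float dry_base           = 1.75f; // default from common_sampler_params
        const int   dry_allowed_length = 2;     // repetitions up to this length are not penalized

        for (int repetition_length = 2; repetition_length <= 8; ++repetition_length) {
            const float penalty = dry_multiplier * std::pow(dry_base, repetition_length - dry_allowed_length);
            std::printf("repetition of %d tokens -> penalty %.3f\n", repetition_length, penalty);
        }
        // the penalty grows exponentially: 0.800, 1.400, 2.450, 4.288, ...
    }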
@@ -376,8 +382,6 @@ bool set_process_priority(enum ggml_sched_priority prio);
 LLAMA_COMMON_ATTRIBUTE_FORMAT(1, 2)
 std::string string_format(const char * fmt, ...);

-std::vector<std::string> string_split(std::string input, char separator);
-
 std::string string_strip(const std::string & str);
 std::string string_get_sortable_timestamp();
@@ -385,6 +389,7 @@ void string_replace_all(std::string & s, const std::string & search, const std::

 template<class T>
 static std::vector<T> string_split(const std::string & str, char delim) {
+    static_assert(!std::is_same<T, std::string>::value, "Please use the specialized version for std::string");
     std::vector<T> values;
     std::istringstream str_stream(str);
     std::string token;
@@ -397,6 +402,22 @@ static std::vector<T> string_split(const std::string & str, char delim) {
     return values;
 }

+template<>
+std::vector<std::string> string_split<std::string>(const std::string & input, char separator)
+{
+    std::vector<std::string> parts;
+    size_t begin_pos = 0;
+    size_t separator_pos = input.find(separator);
+    while (separator_pos != std::string::npos) {
+        std::string part = input.substr(begin_pos, separator_pos - begin_pos);
+        parts.emplace_back(part);
+        begin_pos = separator_pos + 1;
+        separator_pos = input.find(separator, begin_pos);
+    }
+    parts.emplace_back(input.substr(begin_pos, separator_pos - begin_pos));
+    return parts;
+}
+
 bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides);
 void string_process_escapes(std::string & input);
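With the specialization in place, callers pick the element type explicitly: string_split<std::string> uses the find()-based splitter above, while numeric types still go through the generic stream-based template (the static_assert added earlier keeps std::string out of it). A self-contained usage sketch; the body of the generic template is approximated from this diff's context lines:

    #include <cassert>
    #include <sstream>
    #include <string>
    #include <type_traits>
    #include <vector>

    // Local copies of the helpers from this diff, so the example stands alone.
    template<class T>
    static std::vector<T> string_split(const std::string & str, char delim) {
        static_assert(!std::is_same<T, std::string>::value, "Please use the specialized version for std::string");
        std::vector<T> values;
        std::istringstream str_stream(str);
        std::string token;
        while (std::getline(str_stream, token, delim)) {
            T value;
            std::istringstream token_stream(token);
            token_stream >> value;
            values.push_back(value);
        }
        return values;
    }

    template<>
    std::vector<std::string> string_split<std::string>(const std::string & input, char separator) {
        std::vector<std::string> parts;
        size_t begin_pos = 0;
        size_t separator_pos = input.find(separator);
        while (separator_pos != std::string::npos) {
            parts.emplace_back(input.substr(begin_pos, separator_pos - begin_pos));
            begin_pos = separator_pos + 1;
            separator_pos = input.find(separator, begin_pos);
        }
        parts.emplace_back(input.substr(begin_pos));
        return parts;
    }

    int main() {
        // explicit <std::string> selects the find()-based splitter
        auto parts = string_split<std::string>("org/repo/model.gguf", '/');
        assert(parts.size() == 3 && parts.back() == "model.gguf");

        // numeric element types still use the generic stream-based template
        auto probs = string_split<float>("0.1;0.2;0.7", ';');
        assert(probs.size() == 3);
    }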
@@ -130,10 +130,12 @@ std::string common_sampler_params::print() const {

     snprintf(result, sizeof(result),
             "\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
-            "\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, temp = %.3f\n"
+            "\tdry_multiplier = %.3f, dry_base = %.3f, dry_allowed_length = %d, dry_penalty_last_n = %d\n"
+            "\ttop_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, temp = %.3f\n"
             "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
             penalty_last_n, penalty_repeat, penalty_freq, penalty_present,
-            top_k, tfs_z, top_p, min_p, xtc_probability, xtc_threshold, typ_p, temp,
+            dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n,
+            top_k, top_p, min_p, xtc_probability, xtc_threshold, typ_p, temp,
             mirostat, mirostat_eta, mirostat_tau);

     return std::string(result);
@@ -174,6 +176,17 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
     if (params.mirostat == 0) {
         for (const auto & cnstr : params.samplers) {
             switch (cnstr) {
+                case COMMON_SAMPLER_TYPE_DRY:
+                    {
+                        std::vector<const char*> c_breakers;
+                        c_breakers.reserve(params.dry_sequence_breakers.size());
+                        for (const auto& str : params.dry_sequence_breakers) {
+                            c_breakers.push_back(str.c_str());
+                        }
+
+                        llama_sampler_chain_add(result->chain, llama_sampler_init_dry(model, params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
+                    }
+                    break;
                 case COMMON_SAMPLER_TYPE_TOP_K:
                     llama_sampler_chain_add(result->chain, llama_sampler_init_top_k(params.top_k));
                     break;
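For orientation, a rough sketch of how the new fields are expected to flow into the chain built above. It is only a sketch: it assumes the common_sampler_init(model, params) entry point named in the hunk header, the common.h defaults from this commit, and the repository's common/sampling.h include; later sampling and cleanup calls are hinted at in comments only:

    #include "sampling.h" // assumption: the common/sampling.h header from this repository

    void build_dry_sampler(const llama_model * model) {
        common_sampler_params sparams;           // defaults: dry_multiplier = 0.0, i.e. DRY disabled
        sparams.dry_multiplier     = 0.8f;       // enable DRY (example value)
        sparams.dry_base           = 1.75f;      // growth base of the exponential penalty
        sparams.dry_allowed_length = 2;          // repetitions up to this length are not penalized
        sparams.dry_penalty_last_n = -1;         // scan the whole context
        // sparams.samplers already starts with COMMON_SAMPLER_TYPE_DRY, so the DRY case above runs first

        common_sampler * smpl = common_sampler_init(model, sparams);
        (void) smpl; // then sample with the common_sampler_* helpers and free the sampler when done
    }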
@@ -186,9 +199,6 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
                 case COMMON_SAMPLER_TYPE_XTC:
                     llama_sampler_chain_add(result->chain, llama_sampler_init_xtc(params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
                     break;
-                case COMMON_SAMPLER_TYPE_TFS_Z:
-                    llama_sampler_chain_add(result->chain, llama_sampler_init_tail_free(params.tfs_z, params.min_keep));
-                    break;
                 case COMMON_SAMPLER_TYPE_TYPICAL_P:
                     llama_sampler_chain_add(result->chain, llama_sampler_init_typical(params.typ_p, params.min_keep));
                     break;
@@ -358,8 +368,8 @@ std::string common_sampler_prev_str(common_sampler * gsmpl, llama_context * ctx_

 char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
     switch (cnstr) {
+        case COMMON_SAMPLER_TYPE_DRY:       return 'd';
         case COMMON_SAMPLER_TYPE_TOP_K:     return 'k';
-        case COMMON_SAMPLER_TYPE_TFS_Z:     return 'f';
         case COMMON_SAMPLER_TYPE_TYPICAL_P: return 'y';
         case COMMON_SAMPLER_TYPE_TOP_P:     return 'p';
         case COMMON_SAMPLER_TYPE_MIN_P:     return 'm';
@@ -372,8 +382,8 @@ char common_sampler_type_to_chr(enum common_sampler_type cnstr) {

 std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
     switch (cnstr) {
+        case COMMON_SAMPLER_TYPE_DRY:       return "dry";
         case COMMON_SAMPLER_TYPE_TOP_K:     return "top_k";
-        case COMMON_SAMPLER_TYPE_TFS_Z:     return "tfs_z";
         case COMMON_SAMPLER_TYPE_TYPICAL_P: return "typ_p";
         case COMMON_SAMPLER_TYPE_TOP_P:     return "top_p";
         case COMMON_SAMPLER_TYPE_MIN_P:     return "min_p";
@@ -386,11 +396,11 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {

 std::vector<common_sampler_type> common_sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names) {
     std::unordered_map<std::string, common_sampler_type> sampler_canonical_name_map {
+        { "dry",         COMMON_SAMPLER_TYPE_DRY },
         { "top_k",       COMMON_SAMPLER_TYPE_TOP_K },
         { "top_p",       COMMON_SAMPLER_TYPE_TOP_P },
         { "typ_p",       COMMON_SAMPLER_TYPE_TYPICAL_P },
         { "min_p",       COMMON_SAMPLER_TYPE_MIN_P },
-        { "tfs_z",       COMMON_SAMPLER_TYPE_TFS_Z },
         { "temperature", COMMON_SAMPLER_TYPE_TEMPERATURE },
         { "xtc",         COMMON_SAMPLER_TYPE_XTC },
         { "infill",      COMMON_SAMPLER_TYPE_INFILL },
@@ -407,8 +417,6 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
         { "typ-p",       COMMON_SAMPLER_TYPE_TYPICAL_P },
         { "typ",         COMMON_SAMPLER_TYPE_TYPICAL_P },
         { "min-p",       COMMON_SAMPLER_TYPE_MIN_P },
-        { "tfs-z",       COMMON_SAMPLER_TYPE_TFS_Z },
-        { "tfs",         COMMON_SAMPLER_TYPE_TFS_Z },
         { "temp",        COMMON_SAMPLER_TYPE_TEMPERATURE },
     };

@@ -434,8 +442,8 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect

 std::vector<common_sampler_type> common_sampler_types_from_chars(const std::string & chars) {
     std::unordered_map<char, common_sampler_type> sampler_name_map = {
+        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_DRY),       COMMON_SAMPLER_TYPE_DRY },
         { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_K),     COMMON_SAMPLER_TYPE_TOP_K },
-        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TFS_Z),     COMMON_SAMPLER_TYPE_TFS_Z },
         { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TYPICAL_P), COMMON_SAMPLER_TYPE_TYPICAL_P },
         { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_P),     COMMON_SAMPLER_TYPE_TOP_P },
         { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_MIN_P),     COMMON_SAMPLER_TYPE_MIN_P },
@@ -573,6 +573,9 @@ class Model:
         if chkhsh == "0876d13b50744004aa9aeae05e7b0647eac9d801b5ba4668afc01e709c15e19f":
             # ref: https://huggingface.co/BAAI/bge-small-en-v1.5
             res = "bert-bge"
+        if chkhsh == "8e62295832751ca1e8f92f2226f403dea30dc5165e448b5bfa05af5340c64ec7":
+            # ref: https://huggingface.co/BAAI/bge-large-zh-v1.5
+            res = "bert-bge-large"
         if chkhsh == "b6dc8df998e1cfbdc4eac8243701a65afe638679230920b50d6f17d81c098166":
             # ref: https://huggingface.co/mosaicml/mpt-7b
             res = "mpt"
@@ -72,6 +72,7 @@ models = [
     {"name": "deepseek-coder", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/deepseek-ai/deepseek-coder-6.7b-base", },
     {"name": "falcon", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/falcon-7b", },
     {"name": "bert-bge", "tokt": TOKENIZER_TYPE.WPM, "repo": "https://huggingface.co/BAAI/bge-small-en-v1.5", },
+    {"name": "bert-bge-large", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/BAAI/bge-large-zh-v1.5", },
     {"name": "mpt", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/mosaicml/mpt-7b", },
     {"name": "starcoder", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/bigcode/starcoder2-3b", },
     {"name": "gpt-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/openai-community/gpt2", },
@@ -40,12 +40,15 @@
     repeat_last_n: 0, // 0 = disable penalty, -1 = context size
     repeat_penalty: 1.0, // 1.0 = disabled
     penalize_nl: false, // true only useful for infinite completion
+    dry_multiplier: 0.0, // 0.0 = disabled, 0.8 works well
+    dry_base: 1.75, // 0.0 = disabled
+    dry_allowed_length: 2, // tokens extending repetitions beyond this receive penalty, 2 works well
+    dry_penalty_last_n: -1, // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size)
     top_k: 0, // <= 0 to use vocab size
     top_p: 1.0, // 1.0 = disabled
     min_p: 0.05, // 0 = disabled; recommended for non-english: ~ 0.4
     xtc_probability: 0.0, // 0 = disabled;
     xtc_threshold: 0.1, // > 0.5 disables XTC;
-    tfs_z: 1.0, // 1.0 = disabled
     typical_p: 1.0, // 1.0 = disabled
     presence_penalty: 0.0, // 0.0 = disabled
     frequency_penalty: 0.0, // 0.0 = disabled
@@ -833,13 +836,16 @@ return html`
             <fieldset class="params">
               ${IntField({ label: "Top-K", title: "Limits the selection of the next token to the K most probable tokens. 1 means no randomness = greedy sampling. If set to 0, it means the entire vocabulary size is considered.", max: 100, min: 0, step: 1, name: "top_k", value: params.value.top_k })}
               ${IntField({ label: "Penalize Last N", title: "The last n tokens that are taken into account to penalise repetitions. A value of 0 means that this function is deactivated and -1 means that the entire size of the context is taken into account.", max: 2048, min: 0, step: 16, name: "repeat_last_n", value: params.value.repeat_last_n })}
-              ${FloatField({ label: "Top-P", title: "Limits the selection of the next token to a subset of tokens whose combined probability reaches a threshold value P = top-P. If set to 1, it means the entire vocabulary size is considered.", max: 1.0, min: 0.0, name: "top_p", step: 0.01, value: params.value.top_p })}
               ${FloatField({ label: "Presence Penalty", title: "A penalty that is applied if certain tokens appear repeatedly in the generated text. A higher value leads to fewer repetitions.", max: 1.0, min: 0.0, name: "presence_penalty", step: 0.01, value: params.value.presence_penalty })}
-              ${FloatField({ label: "TFS-Z", title: "Activates tail-free sampling, a method used to limit the prediction of tokens that are too frequent. The parameter z controls the strength of this limitation. A value of 1.0 means that this function is deactivated.", max: 1.0, min: 0.0, name: "tfs_z", step: 0.01, value: params.value.tfs_z })}
               ${FloatField({ label: "Frequency Penalty", title: "A penalty that is applied based on the frequency with which certain tokens occur in the training data set. A higher value results in rare tokens being favoured.", max: 1.0, min: 0.0, name: "frequency_penalty", step: 0.01, value: params.value.frequency_penalty })}
+              ${FloatField({ label: "Top-P", title: "Limits the selection of the next token to a subset of tokens whose combined probability reaches a threshold value P = top-P. If set to 1, it means the entire vocabulary size is considered.", max: 1.0, min: 0.0, name: "top_p", step: 0.01, value: params.value.top_p })}
               ${FloatField({ label: "Typical-P", title: "Activates local typical sampling, a method used to limit the prediction of tokens that are atypical in the current context. The parameter p controls the strength of this limitation. A value of 1.0 means that this function is deactivated.", max: 1.0, min: 0.0, name: "typical_p", step: 0.01, value: params.value.typical_p })}
               ${FloatField({ label: "XTC probability", title: "Sets the chance for token removal (checked once on sampler start)", max: 1.0, min: 0.0, name: "xtc_probability", step: 0.01, value: params.value.xtc_probability })}
               ${FloatField({ label: "XTC threshold", title: "Sets a minimum probability threshold for tokens to be removed", max: 0.5, min: 0.0, name: "xtc_threshold", step: 0.01, value: params.value.xtc_threshold })}
+              ${FloatField({ label: "DRY Penalty Multiplier", title: "Set the DRY repetition penalty multiplier. Default is 0.0, which disables DRY.", max: 5.0, min: 0.0, name: "dry_multiplier", step: 0.01, value: params.value.dry_multiplier })}
+              ${FloatField({ label: "DRY Base", title: "Set the DRY repetition penalty base value. Default is 1.75", max: 3.0, min: 1.0, name: "dry_base", step: 0.01, value: params.value.dry_base })}
+              ${IntField({ label: "DRY Allowed Length", title: "Tokens that extend repetition beyond this receive exponentially increasing penalty. Default is 2", max: 10, min: 1, step: 1, name: "dry_allowed_length", value: params.value.dry_allowed_length })}
+              ${IntField({ label: "DRY Penalty Last N", title: "How many tokens to scan for repetitions. Default is -1, where 0 is disabled and -1 is context size", max: 2048, min: -1, step: 16, name: "dry_penalty_last_n", value: params.value.dry_penalty_last_n })}
               ${IntField({ label: "Min Keep", title: "If greater than 0, samplers are forced to return N possible tokens at minimum. Default is 0", max: 10, min: 0, name: "min_keep", value: params.value.min_keep })}
             </fieldset>
@@ -1139,11 +1145,12 @@ document.addEventListener('DOMContentLoaded', (event) => {
         xtc_probability: { snapValue: 0.0, snapRangeMultiplier: 4 },
         xtc_threshold: { snapValue: 0.5, snapRangeMultiplier: 4 },
         top_p: { snapValue: 1.0, snapRangeMultiplier: 4 },
-        tfs_z: { snapValue: 1.0, snapRangeMultiplier: 4 },
         typical_p: { snapValue: 1.0, snapRangeMultiplier: 4 },
         repeat_penalty: { snapValue: 1.0, snapRangeMultiplier: 4 },
         presence_penalty: { snapValue: 0.0, snapRangeMultiplier: 4 },
         frequency_penalty: { snapValue: 0.0, snapRangeMultiplier: 4 },
+        dry_multiplier: { snapValue: 0.0, snapRangeMultiplier: 4 },
+        dry_base: { snapValue: 1.75, snapRangeMultiplier: 4 },
     };
     // add an event listener for each slider
     Object.keys(snapSettings).forEach(sliderName => {
@@ -304,12 +304,15 @@
     repeat_last_n: 256, // 0 = disable penalty, -1 = context size
     repeat_penalty: 1.18, // 1.0 = disabled
     penalize_nl: false,
+    dry_multiplier: 0.0, // 0.0 = disabled, 0.8 works well
+    dry_base: 1.75, // 0.0 = disabled
+    dry_allowed_length: 2, // tokens extending repetitions beyond this receive penalty, 2 works well
+    dry_penalty_last_n: -1, // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size)
     top_k: 40, // <= 0 to use vocab size
     top_p: 0.95, // 1.0 = disabled
     min_p: 0.05, // 0 = disabled
     xtc_probability: 0.0, // 0 = disabled;
     xtc_threshold: 0.1, // > 0.5 disables XTC;
-    tfs_z: 1.0, // 1.0 = disabled
     typical_p: 1.0, // 1.0 = disabled
     presence_penalty: 0.0, // 0.0 = disabled
     frequency_penalty: 0.0, // 0.0 = disabled
@@ -1011,10 +1014,13 @@
           <details>
             <summary>More options</summary>
             <fieldset class="two">
-              ${FloatField({ label: "TFS-Z", max: 1.0, min: 0.0, name: "tfs_z", step: 0.01, value: params.value.tfs_z })}
               ${FloatField({ label: "Typical P", max: 1.0, min: 0.0, name: "typical_p", step: 0.01, value: params.value.typical_p })}
               ${FloatField({ label: "Presence penalty", max: 1.0, min: 0.0, name: "presence_penalty", step: 0.01, value: params.value.presence_penalty })}
               ${FloatField({ label: "Frequency penalty", max: 1.0, min: 0.0, name: "frequency_penalty", step: 0.01, value: params.value.frequency_penalty })}
+              ${FloatField({ label: "DRY Penalty Multiplier", max: 5.0, min: 0.0, name: "dry_multiplier", step: 0.01, value: params.value.dry_multiplier })}
+              ${FloatField({ label: "DRY Base", max: 3.0, min: 1.0, name: "dry_base", step: 0.01, value: params.value.dry_base })}
+              ${IntField({ label: "DRY Allowed Length", max: 10, min: 2, step: 1, name: "dry_allowed_length", value: params.value.dry_allowed_length })}
+              ${IntField({ label: "DRY Penalty Last N", max: 2048, min: -1, step: 16, name: "dry_penalty_last_n", value: params.value.dry_penalty_last_n })}
               ${FloatField({ label: "XTC probability", max: 1.0, min: 0.0, name: "xtc_probability", step: 0.01, value: params.value.xtc_probability })}
               ${FloatField({ label: "XTC threshold", max: 0.5, min: 0.0, name: "xtc_threshold", step: 0.01, value: params.value.xtc_threshold })}
             </fieldset>
examples/server/public/style.css: 0 changes (mode changed: Executable file → Normal file)
@@ -44,21 +44,6 @@
 #include <unordered_map>
 #include <unordered_set>

-#define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
-#define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
-#define SLT_ERR(slot, fmt, ...) LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
-#define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
-
-#define SRV_INF(fmt, ...) LOG_INF("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
-#define SRV_WRN(fmt, ...) LOG_WRN("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
-#define SRV_ERR(fmt, ...) LOG_ERR("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
-#define SRV_DBG(fmt, ...) LOG_DBG("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
-
-#define QUE_INF(fmt, ...) LOG_INF("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
-#define QUE_WRN(fmt, ...) LOG_WRN("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
-#define QUE_ERR(fmt, ...) LOG_ERR("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
-#define QUE_DBG(fmt, ...) LOG_DBG("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
-
 using json = nlohmann::ordered_json;

 enum stop_type {
@@ -69,6 +54,7 @@ enum stop_type {
 // state diagram: https://github.com/ggerganov/llama.cpp/pull/9283
 enum slot_state {
     SLOT_STATE_IDLE,
+    SLOT_STATE_STARTED, // TODO: this state is only used for setting up the initial prompt processing; maybe merge it with launch_slot_with_task in the future
     SLOT_STATE_PROCESSING_PROMPT,
     SLOT_STATE_DONE_PROMPT,
     SLOT_STATE_GENERATING,
@@ -80,7 +66,7 @@ enum server_state {
 };

 enum server_task_type {
-    SERVER_TASK_TYPE_COMPLETION,
+    SERVER_TASK_TYPE_INFERENCE,
     SERVER_TASK_TYPE_CANCEL,
     SERVER_TASK_TYPE_NEXT_RESPONSE,
     SERVER_TASK_TYPE_METRICS,
@@ -90,21 +76,22 @@ enum server_task_type {
     SERVER_TASK_TYPE_SET_LORA,
 };

-enum server_task_cmpl_type {
-    SERVER_TASK_CMPL_TYPE_NORMAL,
-    SERVER_TASK_CMPL_TYPE_EMBEDDING,
-    SERVER_TASK_CMPL_TYPE_RERANK,
-    SERVER_TASK_CMPL_TYPE_INFILL,
+enum server_task_inf_type {
+    SERVER_TASK_INF_TYPE_COMPLETION,
+    SERVER_TASK_INF_TYPE_EMBEDDING,
+    SERVER_TASK_INF_TYPE_RERANK,
+    SERVER_TASK_INF_TYPE_INFILL,
 };

 struct server_task {
     int id = -1; // to be filled by server_queue
     int id_target = -1; // used by SERVER_TASK_TYPE_CANCEL

+    llama_tokens prompt_tokens;
     server_task_type type;
     json data;

-    server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
+    server_task_inf_type inf_type = SERVER_TASK_INF_TYPE_COMPLETION;

     // utility function
     static std::unordered_set<int> get_list_id(const std::vector<server_task> & tasks) {
@@ -162,26 +149,20 @@ struct server_slot {
     int32_t i_batch = -1;
     int32_t n_predict = -1; // TODO: disambiguate from params.n_predict

+    // n_prompt_tokens may not be equal to prompt_tokens.size(), because prompt maybe truncated
     int32_t n_prompt_tokens = 0;
     int32_t n_prompt_tokens_processed = 0;

-    json prompt; // can be either a string, array of strings or array of token ids
-
-    json input_prefix;
-    json input_suffix;
-    json input_extra;
-
-    // when a task is submitted, we first tokenize the prompt and store it here
-    std::vector<llama_token> prompt_tokens;
-    std::vector<llama_token> extra_tokens;
+    // input prompt tokens
+    llama_tokens prompt_tokens;

     size_t last_nl_pos = 0;

     std::string generated_text;
-    std::vector<llama_token> cache_tokens;
+    llama_tokens cache_tokens;
     std::vector<completion_token_output> generated_token_probs;

-    server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
+    server_task_inf_type inf_type = SERVER_TASK_INF_TYPE_COMPLETION;

     bool has_next_token = true;
     bool has_new_line = false;
@@ -230,7 +211,7 @@ struct server_slot {
         n_past = 0;
         n_sent_text = 0;
         n_sent_token_probs = 0;
-        cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
+        inf_type = SERVER_TASK_INF_TYPE_COMPLETION;

         generated_token_probs.clear();
     }
@@ -735,42 +716,6 @@ struct server_context {
         metrics.init();
     }

-    std::vector<llama_token> tokenize(const json & json_prompt, bool add_special, bool parse_special) const {
-        // If `add_bos` is true, we only add BOS, when json_prompt is a string,
-        // or the first element of the json_prompt array is a string.
-        std::vector<llama_token> prompt_tokens;
-
-        if (json_prompt.is_array()) {
-            bool first = true;
-            for (const auto & p : json_prompt) {
-                if (p.is_string()) {
-                    auto s = p.template get<std::string>();
-
-                    std::vector<llama_token> p;
-                    if (first) {
-                        p = common_tokenize(ctx, s, add_special, parse_special);
-                        first = false;
-                    } else {
-                        p = common_tokenize(ctx, s, false, parse_special);
-                    }
-
-                    prompt_tokens.insert(prompt_tokens.end(), p.begin(), p.end());
-                } else {
-                    if (first) {
-                        first = false;
-                    }
-
-                    prompt_tokens.push_back(p.template get<llama_token>());
-                }
-            }
-        } else {
-            auto s = json_prompt.template get<std::string>();
-            prompt_tokens = common_tokenize(ctx, s, add_special, parse_special);
-        }
-
-        return prompt_tokens;
-    }
-
     server_slot * get_slot_by_id(int id) {
         for (server_slot & slot : slots) {
             if (slot.id == id) {
@@ -795,22 +740,16 @@ struct server_context {
                 continue;
             }

-            // skip the slot if it does not contains prompt
-            if (!slot.prompt.is_string()) {
+            // skip the slot if it does not contains cached tokens
+            if (slot.prompt_tokens.empty()) {
                 continue;
             }

-            // current slot's prompt
-            std::string slot_prompt = slot.prompt.get<std::string>();
-
-            // length of the current slot's prompt
-            int slot_prompt_len = slot_prompt.size();
-
             // length of the Longest Common Prefix between the current slot's prompt and the input prompt
-            int lcp_len = longest_common_prefix(slot_prompt, prompt);
+            int lcp_len = longest_common_prefix(slot.cache_tokens, slot.prompt_tokens);

             // fraction of the common substring length compared to the current slot's prompt length
-            similarity = static_cast<float>(lcp_len) / slot_prompt_len;
+            similarity = static_cast<float>(lcp_len) / static_cast<int>(slot.prompt_tokens.size());

             // select the current slot if the criteria match
             if (lcp_len > max_lcp_len && similarity > slot_prompt_similarity) {
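The slot-reuse heuristic now works on token IDs rather than prompt strings. A self-contained sketch of the idea; longest_common_prefix below is a simplified stand-in for the server's helper, and the token values are made up:

    #include <cassert>
    #include <cstdint>
    #include <vector>

    using llama_token = int32_t; // token IDs, as in llama.h

    // simplified stand-in for the helper used above
    static size_t longest_common_prefix(const std::vector<llama_token> & a, const std::vector<llama_token> & b) {
        size_t i = 0;
        while (i < a.size() && i < b.size() && a[i] == b[i]) {
            i++;
        }
        return i;
    }

    int main() {
        std::vector<llama_token> cache_tokens  = { 1, 15043, 29892, 3186, 29991 };      // what the slot already processed
        std::vector<llama_token> prompt_tokens = { 1, 15043, 29892, 3186, 322, 901 };   // the incoming prompt

        const size_t lcp_len    = longest_common_prefix(cache_tokens, prompt_tokens);
        const float  similarity = static_cast<float>(lcp_len) / prompt_tokens.size();

        assert(lcp_len == 4);
        // a slot is preferred when its common prefix is longest and similarity exceeds slot_prompt_similarity
        assert(similarity > 0.5f);
    }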
@ -862,35 +801,57 @@ struct server_context {
|
||||||
slot.oaicompat_model = "";
|
slot.oaicompat_model = "";
|
||||||
}
|
}
|
||||||
|
|
||||||
slot.params.stream = json_value(data, "stream", false);
slot.params.cache_prompt = json_value(data, "cache_prompt", false);
slot.params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", default_params.n_predict));
slot.params.n_indent = json_value(data, "n_indent", default_params.n_indent);
slot.sparams.top_k = json_value(data, "top_k", default_sparams.top_k);
slot.sparams.top_p = json_value(data, "top_p", default_sparams.top_p);
slot.sparams.min_p = json_value(data, "min_p", default_sparams.min_p);
slot.sparams.xtc_probability = json_value(data, "xtc_probability", default_sparams.xtc_probability);
slot.sparams.xtc_threshold = json_value(data, "xtc_threshold", default_sparams.xtc_threshold);
-   slot.sparams.tfs_z = json_value(data, "tfs_z", default_sparams.tfs_z);
slot.sparams.typ_p = json_value(data, "typical_p", default_sparams.typ_p);
slot.sparams.temp = json_value(data, "temperature", default_sparams.temp);
slot.sparams.dynatemp_range = json_value(data, "dynatemp_range", default_sparams.dynatemp_range);
slot.sparams.dynatemp_exponent = json_value(data, "dynatemp_exponent", default_sparams.dynatemp_exponent);
slot.sparams.penalty_last_n = json_value(data, "repeat_last_n", default_sparams.penalty_last_n);
slot.sparams.penalty_repeat = json_value(data, "repeat_penalty", default_sparams.penalty_repeat);
slot.sparams.penalty_freq = json_value(data, "frequency_penalty", default_sparams.penalty_freq);
slot.sparams.penalty_present = json_value(data, "presence_penalty", default_sparams.penalty_present);
+   slot.sparams.dry_multiplier = json_value(data, "dry_multiplier", default_sparams.dry_multiplier);
+   slot.sparams.dry_base = json_value(data, "dry_base", default_sparams.dry_base);
+   slot.sparams.dry_allowed_length = json_value(data, "dry_allowed_length", default_sparams.dry_allowed_length);
+   slot.sparams.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", default_sparams.dry_penalty_last_n);
slot.sparams.mirostat = json_value(data, "mirostat", default_sparams.mirostat);
slot.sparams.mirostat_tau = json_value(data, "mirostat_tau", default_sparams.mirostat_tau);
slot.sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta);
slot.sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl);
slot.params.n_keep = json_value(data, "n_keep", default_params.n_keep);
slot.params.n_discard = json_value(data, "n_discard", default_params.n_discard);
slot.sparams.seed = json_value(data, "seed", default_sparams.seed);
slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);
//slot.params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", default_params.t_max_prompt_ms); // TODO: implement
slot.params.t_max_predict_ms = json_value(data, "t_max_predict_ms", default_params.t_max_predict_ms);

+   if (slot.sparams.dry_base < 1.0f)
+   {
+       slot.sparams.dry_base = default_sparams.dry_base;
+   }
+
+   // sequence breakers for DRY
+   {
+       // Currently, this is not compatible with TextGen WebUI, Koboldcpp and SillyTavern format
+       // Ref: https://github.com/oobabooga/text-generation-webui/blob/d1af7a41ade7bd3c3a463bfa640725edb818ebaf/extensions/openai/typing.py#L39
+
+       if (data.contains("dry_sequence_breakers")) {
+           slot.sparams.dry_sequence_breakers = json_value(data, "dry_sequence_breakers", std::vector<std::string>());
+           if (slot.sparams.dry_sequence_breakers.empty()) {
+               send_error(task, "Error: dry_sequence_breakers must be a non-empty array of strings", ERROR_TYPE_INVALID_REQUEST);
+               return false;
+           }
+       }
+   }
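Note (added for illustration, not part of the commit): the DRY fields above are plain JSON keys read with json_value(). A request that sets them might look like the sketch below; the numeric values and breaker strings are examples, not the server defaults.

#include <nlohmann/json.hpp>

// Illustrative /completion request body exercising the DRY sampler fields (values are made up):
static const nlohmann::ordered_json k_example_dry_request = {
    {"prompt",                "Once upon a time"},
    {"dry_multiplier",        0.8},
    {"dry_base",              1.75},
    {"dry_allowed_length",    2},
    {"dry_penalty_last_n",    -1},
    {"dry_sequence_breakers", {"\n", ":", "\"", "*"}}
};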
// process "json_schema" and "grammar"
if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {

@@ -915,57 +876,6 @@ struct server_context {
SLT_WRN(slot, "n_predict = %d exceeds server configuration, setting to %d", slot.n_predict, slot.n_predict);
}

-   // infill
-   slot.input_prefix = json_value(data, "input_prefix", json());
-   slot.input_suffix = json_value(data, "input_suffix", json());
-   slot.input_extra = json_value(data, "input_extra", json());
-
-   SLT_DBG(slot, "extra_context chunks: %d\n", (int) slot.input_extra.size());
-   for (const auto & chunk : slot.input_extra) {
-       // { "text": string, "filename": string }
-       if (!chunk.contains("text") || !chunk["text"].is_string()) {
-           send_error(task, "extra_context chunk must contain a \"text\" field with a string value", ERROR_TYPE_INVALID_REQUEST);
-           return false;
-       }
-
-       // filename is optional
-       if (chunk.contains("filename") && !chunk["filename"].is_string()) {
-           send_error(task, "extra_context chunk's \"filename\" field must be a string", ERROR_TYPE_INVALID_REQUEST);
-           return false;
-       }
-
-       SLT_DBG(slot, "extra_context chunk in file '%s':\n%s\n", chunk.value("filename", "").c_str(), chunk.value("text", "").c_str());
-   }
-
-   // get prompt
-   {
-       const auto & prompt = data.find("prompt");
-       if (prompt == data.end()) {
-           send_error(task, "\"prompt\" must be provided", ERROR_TYPE_INVALID_REQUEST);
-           return false;
-       }
-
-       if ((prompt->is_string()) ||
-           (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_string()) ||
-           (prompt->is_array() && !prompt->empty() && prompt->at(0).is_number_integer())) {
-           slot.prompt = *prompt;
-       } else if (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_array()) {
-           slot.prompt = prompt->at(0);
-       } else if (prompt->is_array() && prompt->size() > 1) {
-           // array of strings
-           for (const auto & el : *prompt) {
-               if (!el.is_string()) {
-                   send_error(task, "\"prompt\" must be a string, an array of strings or an array of integers", ERROR_TYPE_INVALID_REQUEST);
-                   return false;
-               }
-           }
-           slot.prompt = *prompt;
-       } else {
-           send_error(task, "\"prompt\" must be a string, an array of strings or an array of integers", ERROR_TYPE_INVALID_REQUEST);
-           return false;
-       }
-   }

{
slot.sparams.logit_bias.clear();

@@ -1045,8 +955,7 @@ struct server_context {
}
}

-   slot.state = SLOT_STATE_PROCESSING_PROMPT;
-   slot.prompt_tokens.clear();
+   slot.state = SLOT_STATE_STARTED;

SLT_INF(slot, "%s", "processing task\n");

@@ -1240,12 +1149,16 @@ struct server_context {
{"min_p", slot.sparams.min_p},
{"xtc_probability", slot.sparams.xtc_probability},
{"xtc_threshold", slot.sparams.xtc_threshold},
-   {"tfs_z", slot.sparams.tfs_z},
{"typical_p", slot.sparams.typ_p},
{"repeat_last_n", slot.sparams.penalty_last_n},
{"repeat_penalty", slot.sparams.penalty_repeat},
{"presence_penalty", slot.sparams.penalty_present},
{"frequency_penalty", slot.sparams.penalty_freq},
+   {"dry_multiplier", slot.sparams.dry_multiplier},
+   {"dry_base", slot.sparams.dry_base},
+   {"dry_allowed_length", slot.sparams.dry_allowed_length},
+   {"dry_penalty_last_n", slot.sparams.dry_penalty_last_n},
+   {"dry_sequence_breakers", slot.sparams.dry_sequence_breakers},
{"mirostat", slot.sparams.mirostat},
{"mirostat_tau", slot.sparams.mirostat_tau},
{"mirostat_eta", slot.sparams.mirostat_eta},

@@ -1298,7 +1211,7 @@ struct server_context {
};

if (slot.sparams.n_probs > 0) {
-   const std::vector<llama_token> to_send_toks = common_tokenize(ctx, tkn.text_to_send, false);
+   const llama_tokens to_send_toks = common_tokenize(ctx, tkn.text_to_send, false);
const size_t probs_pos = std::min(slot.n_sent_token_probs, slot.generated_token_probs.size());
const size_t probs_stop_pos = std::min(slot.n_sent_token_probs + to_send_toks.size(), slot.generated_token_probs.size());

@@ -1334,7 +1247,7 @@ struct server_context {
{"tokens_predicted", slot.n_decoded},
{"tokens_evaluated", slot.n_prompt_tokens},
{"generation_settings", get_formated_generation(slot)},
-   {"prompt", slot.prompt},
+   {"prompt", common_detokenize(ctx, slot.prompt_tokens)},
{"has_new_line", slot.has_new_line},
{"truncated", slot.truncated},
{"stopped_eos", slot.stopped_eos},

@@ -1349,7 +1262,7 @@ struct server_context {
if (slot.sparams.n_probs > 0) {
std::vector<completion_token_output> probs;
if (!slot.params.stream && slot.stopped_word) {
-   const std::vector<llama_token> stop_word_toks = common_tokenize(ctx, slot.stopping_word, false);
+   const llama_tokens stop_word_toks = common_tokenize(ctx, slot.stopping_word, false);

size_t safe_offset = std::min(slot.generated_token_probs.size(), stop_word_toks.size());
probs = std::vector<completion_token_output>(
@@ -1458,19 +1371,17 @@ struct server_context {
// Functions to create new task(s) and receive result(s)
//

-   std::vector<server_task> create_tasks_cmpl(json data, server_task_cmpl_type cmpl_type) {
+   // break the input "prompt" into multiple tasks if needed, then format and tokenize the input prompt(s)
+   std::vector<server_task> create_tasks_inference(json data, server_task_inf_type inf_type) {
std::vector<server_task> tasks;
-   auto create_task = [&](json & task_data, bool replace_prompt, json prompt) {
+   auto create_task = [&](json & task_data, llama_tokens & prompt_tokens) {
+       SRV_DBG("create task, n_tokens = %d\n", (int) prompt_tokens.size());
server_task task;
task.id = queue_tasks.get_new_id();
-   task.cmpl_type = cmpl_type;
-   task.type = SERVER_TASK_TYPE_COMPLETION;
-   if (replace_prompt) {
-       task.data = task_data;
-       task.data["prompt"] = std::move(prompt);
-   } else {
-       task.data = std::move(task_data);
-   }
+   task.inf_type = inf_type;
+   task.type = SERVER_TASK_TYPE_INFERENCE;
+   task.data = task_data;
+   task.prompt_tokens = std::move(prompt_tokens);
tasks.push_back(std::move(task));
};

@@ -1479,41 +1390,49 @@ struct server_context {
throw std::runtime_error(error_msg);
}

-   json prompt = data.at("prompt");
-
-   // if the prompt is a singleton (i.e. a string or a list of tokens), we only need to create single task
-   if (prompt.is_string() || json_is_array_of_numbers(prompt)) {
-       data["index"] = 0;
-       create_task(data, false, nullptr);
-   } else if (prompt.is_array()) {
-       // otherwise, it's a multiple-prompt task, we break it into smaller tasks
-       std::vector<json> prompts = prompt;
-       if (cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
-           // prompts[0] is the question
-           // the rest are the answers/documents
-           SRV_DBG("creating rerank tasks, n_prompts = %d\n", (int) prompts.size() - 1);
-           for (size_t i = 1; i < prompts.size(); i++) {
-               json qd;
-               qd.push_back(prompts[0]);
-               qd.push_back(prompts[i]);
-               data["index"] = i - 1;
-               create_task(data, true, qd);
-           }
-       } else {
-           SRV_DBG("creating multi-prompt tasks, n_prompts = %d\n", (int) prompts.size());
-           for (size_t i = 0; i < prompts.size(); i++) {
-               const auto & e = prompts[i];
-               if (e.is_string() || json_is_array_of_numbers(e)) {
-                   data["index"] = i;
-                   create_task(data, true, e);
-               } else {
-                   throw std::runtime_error(error_msg);
-               }
-           }
-       }
-   } else {
-       // invalid case
-       throw std::runtime_error(error_msg);
-   }
+   // because llama_tokenize api is thread-safe, we can tokenize the prompt from HTTP thread
+   bool add_special = inf_type != SERVER_TASK_INF_TYPE_RERANK && inf_type != SERVER_TASK_INF_TYPE_INFILL;
+   std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx, data.at("prompt"), add_special, true);
+   switch (inf_type) {
+       case SERVER_TASK_INF_TYPE_RERANK:
+           {
+               // prompts[0] is the question
+               // the rest are the answers/documents
+               GGML_ASSERT(tokenized_prompts.size() > 1);
+               SRV_DBG("creating rerank tasks, n_prompts = %d\n", (int) tokenized_prompts.size() - 1);
+               for (size_t i = 1; i < tokenized_prompts.size(); i++) {
+                   data["index"] = i - 1;
+                   auto tokens = format_rerank(model, tokenized_prompts[0], tokenized_prompts[i]);
+                   create_task(data, tokens);
+               }
+           } break;
+       case SERVER_TASK_INF_TYPE_INFILL:
+           {
+               SRV_DBG("creating infill tasks, n_prompts = %d\n", (int) tokenized_prompts.size());
+               for (size_t i = 0; i < tokenized_prompts.size(); i++) {
+                   data["index"] = i;
+                   auto tokens = format_infill(
+                       ctx,
+                       data.at("input_prefix"),
+                       data.at("input_suffix"),
+                       data.at("input_extra"),
+                       params.n_batch,
+                       params.n_predict,
+                       slots[0].n_ctx, // TODO: there should be a better way
+                       params.spm_infill,
+                       tokenized_prompts[i]
+                   );
+                   create_task(data, tokens);
+               }
+           } break;
+       default:
+           {
+               SRV_DBG("creating multi-prompt tasks, n_prompts = %d\n", (int) tokenized_prompts.size());
+               for (size_t i = 0; i < tokenized_prompts.size(); i++) {
+                   data["index"] = i;
+                   create_task(data, tokenized_prompts[i]);
+               }
+           }
+   }

return tasks;
@@ -1535,7 +1454,7 @@ struct server_context {
queue_tasks.post(cancel_tasks, true);
}

-   // receive the results from task(s) created by create_tasks_cmpl
+   // receive the results from task(s) created by create_tasks_inference
void receive_cmpl_results(
const std::unordered_set<int> & id_tasks,
const std::function<void(std::vector<server_task_result>&)> & result_handler,

@@ -1559,7 +1478,7 @@ struct server_context {
result_handler(results);
}

-   // receive the results from task(s) created by create_tasks_cmpl, in stream mode
+   // receive the results from task(s) created by create_tasks_inference, in stream mode
void receive_cmpl_results_stream(
const std::unordered_set<int> & id_tasks, const
std::function<bool(server_task_result&)> & result_handler, const

@@ -1592,7 +1511,7 @@ struct server_context {

void process_single_task(const server_task & task) {
switch (task.type) {
-   case SERVER_TASK_TYPE_COMPLETION:
+   case SERVER_TASK_TYPE_INFERENCE:
{
const int id_slot = json_value(task.data, "id_slot", -1);

@@ -1624,9 +1543,10 @@ struct server_context {

slot->reset();

slot->id_task = task.id;
-   slot->cmpl_type = task.cmpl_type;
+   slot->inf_type = task.inf_type;
slot->index = json_value(task.data, "index", 0);
+   slot->prompt_tokens = std::move(task.prompt_tokens);

if (!launch_slot_with_task(*slot, task)) {
SRV_ERR("failed to launch slot with task, id_task = %d\n", task.id);

@@ -1659,7 +1579,7 @@ struct server_context {
slot_data["id"] = slot.id;
slot_data["id_task"] = slot.id_task;
slot_data["state"] = slot.state;
-   slot_data["prompt"] = slot.prompt;
+   slot_data["prompt"] = common_detokenize(ctx, slot.prompt_tokens);
slot_data["next_token"] = {
{"has_next_token", slot.has_next_token},
{"has_new_line", slot.has_new_line},

@@ -1786,9 +1706,6 @@ struct server_context {
}
slot->cache_tokens.resize(token_count);

-   // TODO: maybe detokenize the slot->cache_tokens instead?
-   slot->prompt = string_format("[restored %d tokens from file]", (int) token_count);

const int64_t t_end = ggml_time_us();
const double t_restore_ms = (t_end - t_start) / 1000.0;
@@ -1955,142 +1872,19 @@ struct server_context {
if (params.cont_batching || batch.n_tokens == 0) {
for (auto & slot : slots) {
// this slot still has a prompt to be processed
-   if (slot.state == SLOT_STATE_PROCESSING_PROMPT) {
+   if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) {
auto & prompt_tokens = slot.prompt_tokens;

-   // we haven't tokenized the prompt yet - do it now:
-   if (prompt_tokens.empty()) {
-       SLT_INF(slot, "tokenizing prompt, len = %d\n", (int) slot.prompt.size());
+   // TODO: maybe move branch to outside of this loop in the future
+   if (slot.state == SLOT_STATE_STARTED) {

slot.t_start_process_prompt = ggml_time_us();
slot.t_start_generation = 0;

-       switch (slot.cmpl_type) {
-           case SERVER_TASK_CMPL_TYPE_NORMAL:
-           case SERVER_TASK_CMPL_TYPE_EMBEDDING:
-               {
-                   prompt_tokens = tokenize(slot.prompt, llama_add_bos_token(model), true);
-               } break;
-           case SERVER_TASK_CMPL_TYPE_RERANK:
-               {
-                   // require slot.prompt to be array of 2 strings
-                   if (!slot.prompt.is_array() || slot.prompt.size() != 2) {
-                       SLT_ERR(slot, "%s", "invalid prompt for rerank task\n");
-                       slot.release();
-                       send_error(slot, "invalid prompt for rerank task", ERROR_TYPE_INVALID_REQUEST);
-                       continue;
-                   }
-
-                   // prompt: [BOS]query[EOS][SEP]doc[EOS]
-                   prompt_tokens.clear();
-                   prompt_tokens.push_back(llama_token_bos(model));
-                   {
-                       const auto part = tokenize(slot.prompt[0], false, false);
-                       prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
-                   }
-                   prompt_tokens.push_back(llama_token_eos(model));
-                   prompt_tokens.push_back(llama_token_sep(model));
-                   {
-                       const auto part = tokenize(slot.prompt[1], false, false);
-                       prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
-                   }
-                   prompt_tokens.push_back(llama_token_eos(model));
-               } break;
-           case SERVER_TASK_CMPL_TYPE_INFILL:
-               {
-                   // TODO: optimize this block by reducing memory allocations and movement
-
-                   // use FIM repo-level pattern:
-                   // ref: https://arxiv.org/pdf/2409.12186
-                   //
-                   // [FIM_REP]myproject
-                   // [FIM_SEP]filename0
-                   // extra chunk 0
-                   // [FIM_SEP]filename1
-                   // extra chunk 1
-                   // ...
-                   // [FIM_SEP]filename
-                   // [FIM_PRE]prefix[FIM_SUF]suffix[FIM_MID]prompt
-                   //
-                   auto tokens_prefix = tokenize(slot.input_prefix, false, false);
-                   auto tokens_suffix = tokenize(slot.input_suffix, false, false);
-                   auto tokens_prompt = tokenize(slot.prompt, false, false);
-
-                   slot.extra_tokens.clear();
-                   if (llama_token_fim_rep(model) != LLAMA_TOKEN_NULL) {
-                       static const auto k_fim_repo = tokenize("myproject\n", false, false);
-
-                       slot.extra_tokens.push_back(llama_token_fim_rep(model));
-                       slot.extra_tokens.insert(slot.extra_tokens.end(), k_fim_repo.begin(), k_fim_repo.end());
-                   }
-
-                   for (const auto & chunk : slot.input_extra) {
-                       // { "text": string, "filename": string }
-                       const std::string text = chunk.value("text", "");
-                       const std::string filename = chunk.value("filename", "tmp");
-
-                       if (llama_token_fim_sep(model) != LLAMA_TOKEN_NULL) {
-                           const auto k_fim_file = tokenize(filename + "\n", false, false);
-
-                           slot.extra_tokens.insert(slot.extra_tokens.end(), llama_token_fim_sep(model));
-                           slot.extra_tokens.insert(slot.extra_tokens.end(), k_fim_file.begin(), k_fim_file.end());
-                       } else {
-                           // chunk separator in binary form to avoid confusing the AI
-                           static const char k_chunk_prefix_str[] = {0x0a, 0x0a, 0x2d, 0x2d, 0x2d, 0x20, 0x73, 0x6e, 0x69, 0x70, 0x70, 0x65, 0x74, 0x20, 0x2d, 0x2d, 0x2d, 0x0a, 0x0a, 0x00};
-                           static const auto k_chunk_prefix_tokens = tokenize(k_chunk_prefix_str, false, false);
-
-                           slot.extra_tokens.insert(slot.extra_tokens.end(), k_chunk_prefix_tokens.begin(), k_chunk_prefix_tokens.end());
-                       }
-
-                       const auto chunk_tokens = tokenize(text, false, false);
-                       slot.extra_tokens.insert(slot.extra_tokens.end(), chunk_tokens.begin(), chunk_tokens.end());
-                   }
-
-                   if (llama_token_fim_sep(model) != LLAMA_TOKEN_NULL) {
-                       // TODO: current filename
-                       static const auto k_fim_file = tokenize("filename\n", false, false);
-
-                       slot.extra_tokens.insert(slot.extra_tokens.end(), llama_token_fim_sep(model));
-                       slot.extra_tokens.insert(slot.extra_tokens.end(), k_fim_file.begin(), k_fim_file.end());
-                   }
-
-                   // for now pick FIM context to fit in a batch (ratio prefix:suffix = 3:1, TODO: configurable?)
-                   const int n_suffix_take = std::min<int>(tokens_suffix.size(), (n_batch/4));
-                   const int n_prefix_take = std::min<int>(tokens_prefix.size(), 3*(n_batch/4) - 3);
-
-                   // fill the rest of the context with extra chunks
-                   const int n_extra_take = std::min<int>(std::max<int>(0, slot.n_ctx - (n_batch) - 2*slot.n_predict), slot.extra_tokens.size());
-
-                   tokens_prefix.erase(tokens_prefix.begin(), tokens_prefix.begin() + tokens_prefix.size() - n_prefix_take);
-                   tokens_suffix.resize(n_suffix_take);
-
-                   tokens_prefix.insert(tokens_prefix.begin(), llama_token_fim_pre(model));
-                   tokens_prefix.insert(tokens_prefix.end(), tokens_prompt.begin(), tokens_prompt.end());
-                   tokens_suffix.insert(tokens_suffix.begin(), llama_token_fim_suf(model));
-
-                   auto embd_inp = params.spm_infill ? tokens_suffix : tokens_prefix;
-                   auto embd_end = params.spm_infill ? tokens_prefix : tokens_suffix;
-
-                   if (llama_add_bos_token(model)) {
-                       embd_inp.insert(embd_inp.begin(), llama_token_bos(model));
-                   }
-
-                   SLT_DBG(slot, "extra: n_ctx = %d, n_extra_take = %d, n_extra = %d\n", slot.n_ctx, n_extra_take, (int) slot.extra_tokens.size());
-
-                   // put the extra context before the FIM prefix
-                   embd_inp.insert(embd_inp.begin(), slot.extra_tokens.end() - n_extra_take, slot.extra_tokens.end());
-
-                   embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());
-                   embd_inp.push_back(llama_token_fim_mid(model));
-
-                   prompt_tokens = std::move(embd_inp);
-               } break;
-       }
-
slot.n_past = 0;
slot.n_prompt_tokens = prompt_tokens.size();
+   slot.state = SLOT_STATE_PROCESSING_PROMPT;

-   SLT_INF(slot, "prompt tokenized, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, slot.n_prompt_tokens);
+   SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, slot.n_prompt_tokens);

// print prompt tokens (for debugging)
if (1) {
@@ -2115,13 +1909,18 @@ struct server_context {
continue;
}

-   if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
-       // this prompt is too large to process - discard it
+   if (slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING || slot.inf_type == SERVER_TASK_INF_TYPE_RERANK) {
if (slot.n_prompt_tokens > n_ubatch) {
slot.release();
send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER);
continue;
}
+
+   if (slot.n_prompt_tokens > slot.n_ctx) {
+       slot.release();
+       send_error(slot, "input is larger than the max context size. skipping", ERROR_TYPE_SERVER);
+       continue;
+   }
} else {
if (!params.ctx_shift) {
// if context shift is disabled, we make sure prompt size is smaller than KV size

@@ -2145,7 +1944,7 @@ struct server_context {
const int n_block_size = n_left / 2;
const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;

-   std::vector<llama_token> new_tokens(
+   llama_tokens new_tokens(
prompt_tokens.begin(),
prompt_tokens.begin() + slot.params.n_keep);

@@ -2199,7 +1998,6 @@ struct server_context {

for (size_t i = 0; i < n_match; i++) {
slot.cache_tokens[head_p + i] = slot.cache_tokens[head_c + i];

slot.n_past++;
}

@@ -2226,7 +2024,7 @@ struct server_context {
}

// non-causal tasks require to fit the entire prompt in the physical batch
-   if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
+   if (slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING || slot.inf_type == SERVER_TASK_INF_TYPE_RERANK) {
// cannot fit the prompt in the current batch - will try next iter
if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
continue;

@@ -2235,8 +2033,8 @@ struct server_context {

// check that we are in the right batch_type, if not defer the slot
const bool slot_type =
-   slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING ||
-   slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK ? 1 : 0;
+   slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING ||
+   slot.inf_type == SERVER_TASK_INF_TYPE_RERANK ? 1 : 0;

if (batch_type == -1) {
batch_type = slot_type;

@@ -2354,7 +2152,7 @@ struct server_context {
}

if (slot.state == SLOT_STATE_DONE_PROMPT) {
-   if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING) {
+   if (slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING) {
// prompt evaluated for embedding
send_embedding(slot, batch_view);
slot.release();

@@ -2362,7 +2160,7 @@ struct server_context {
continue; // continue loop of slots
}

-   if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
+   if (slot.inf_type == SERVER_TASK_INF_TYPE_RERANK) {
send_rerank(slot, batch_view);
slot.release();
slot.i_batch = -1;
@@ -2609,7 +2407,7 @@ int main(int argc, char ** argv) {
auto middleware_server_state = [&res_error, &state](const httplib::Request & req, httplib::Response & res) {
server_state current_state = state.load();
if (current_state == SERVER_STATE_LOADING_MODEL) {
-   auto tmp = string_split(req.path, '.');
+   auto tmp = string_split<std::string>(req.path, '.');
if (req.path == "/" || tmp.back() == "html") {
res.set_content(reinterpret_cast<const char*>(loading_html), loading_html_len, "text/html; charset=utf-8");
res.status = 503;

@@ -2916,13 +2714,13 @@ int main(int argc, char ** argv) {
res_ok(res, {{ "success", true }});
};

-   const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](server_task_cmpl_type cmpl_type, json & data, httplib::Response & res) {
+   const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](server_task_inf_type inf_type, json & data, httplib::Response & res) {
if (ctx_server.params.embedding || ctx_server.params.reranking) {
res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings` or `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
return;
}

-   std::vector<server_task> tasks = ctx_server.create_tasks_cmpl(data, cmpl_type);
+   std::vector<server_task> tasks = ctx_server.create_tasks_inference(data, inf_type);
ctx_server.queue_results.add_waiting_tasks(tasks);
ctx_server.queue_tasks.post(tasks);

@@ -2968,10 +2766,11 @@ int main(int argc, char ** argv) {

const auto handle_completions = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
json data = json::parse(req.body);
-   return handle_completions_generic(SERVER_TASK_CMPL_TYPE_NORMAL, data, res);
+   return handle_completions_generic(SERVER_TASK_INF_TYPE_COMPLETION, data, res);
};

const auto handle_infill = [&ctx_server, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
+   // check model compatibility
std::string err;
if (llama_token_fim_pre(ctx_server.model) == LLAMA_TOKEN_NULL) {
err += "prefix token is missing. ";

@@ -2982,14 +2781,42 @@ int main(int argc, char ** argv) {
if (llama_token_fim_mid(ctx_server.model) == LLAMA_TOKEN_NULL) {
err += "middle token is missing. ";
}

if (!err.empty()) {
res_error(res, format_error_response(string_format("Infill is not supported by this model: %s", err.c_str()), ERROR_TYPE_NOT_SUPPORTED));
return;
}

json data = json::parse(req.body);
-   return handle_completions_generic(SERVER_TASK_CMPL_TYPE_INFILL, data, res);
+
+   // validate input
+   if (!data.contains("input_prefix")) {
+       res_error(res, format_error_response("\"input_prefix\" is required", ERROR_TYPE_INVALID_REQUEST));
+   }
+
+   if (!data.contains("input_suffix")) {
+       res_error(res, format_error_response("\"input_suffix\" is required", ERROR_TYPE_INVALID_REQUEST));
+   }
+
+   if (data.contains("input_extra") && !data.at("input_extra").is_array()) {
+       res_error(res, format_error_response("\"input_extra\" must be an array of {\"filename\": string, \"text\": string}", ERROR_TYPE_INVALID_REQUEST));
+       return;
+   }
+   json input_extra = json_value(data, "input_extra", json::array());
+   for (const auto & chunk : input_extra) {
+       // { "text": string, "filename": string }
+       if (!chunk.contains("text") || !chunk.at("text").is_string()) {
+           res_error(res, format_error_response("extra_context chunk must contain a \"text\" field with a string value", ERROR_TYPE_INVALID_REQUEST));
+           return;
+       }
+       // filename is optional
+       if (chunk.contains("filename") && !chunk.at("filename").is_string()) {
+           res_error(res, format_error_response("extra_context chunk's \"filename\" field must be a string", ERROR_TYPE_INVALID_REQUEST));
+           return;
+       }
+   }
+   data["input_extra"] = input_extra; // default to empty array if it's not exist
+
+   return handle_completions_generic(SERVER_TASK_INF_TYPE_INFILL, data, res);
};
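Note (added for illustration, not part of the commit): with the validation above, an /infill request must supply "input_prefix" and "input_suffix", and may pass an optional "input_extra" array of {"filename", "text"} chunks. A minimal body could be built like the sketch below; the strings are made up.

#include <nlohmann/json.hpp>

// Illustrative only: builds an /infill request body matching the validation above.
static nlohmann::ordered_json make_example_infill_request() {
    using json = nlohmann::ordered_json;
    json body;
    body["input_prefix"] = "int main() {\n    int n_threads = llama_";
    body["input_suffix"] = "}\n";
    body["input_extra"]  = json::array({ { {"filename", "llama.h"}, {"text", "LLAMA_API int32_t llama_n_threads();\n"} } });
    body["prompt"]       = "";
    body["n_predict"]    = 64;
    return body;
}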

// TODO: maybe merge this function with "handle_completions_generic"

@@ -3001,7 +2828,7 @@ int main(int argc, char ** argv) {

json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);

-   std::vector<server_task> tasks = ctx_server.create_tasks_cmpl(data, SERVER_TASK_CMPL_TYPE_NORMAL);
+   std::vector<server_task> tasks = ctx_server.create_tasks_inference(data, SERVER_TASK_INF_TYPE_COMPLETION);
ctx_server.queue_results.add_waiting_tasks(tasks);
ctx_server.queue_tasks.post(tasks);

@@ -3074,7 +2901,7 @@ int main(int argc, char ** argv) {
const bool add_special = json_value(body, "add_special", false);
const bool with_pieces = json_value(body, "with_pieces", false);

-   std::vector<llama_token> tokens = ctx_server.tokenize(body.at("content"), add_special, true);
+   llama_tokens tokens = tokenize_mixed(ctx_server.ctx, body.at("content"), add_special, true);

if (with_pieces) {
for (const auto& token : tokens) {

@@ -3111,7 +2938,7 @@ int main(int argc, char ** argv) {

std::string content;
if (body.count("tokens") != 0) {
-   const std::vector<llama_token> tokens = body.at("tokens");
+   const llama_tokens tokens = body.at("tokens");
content = tokens_to_str(ctx_server.ctx, tokens.cbegin(), tokens.cend());
}

@@ -3145,7 +2972,7 @@ int main(int argc, char ** argv) {
json responses = json::array();
bool error = false;
{
-   std::vector<server_task> tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_EMBEDDING);
+   std::vector<server_task> tasks = ctx_server.create_tasks_inference({{"prompt", prompt}}, SERVER_TASK_INF_TYPE_EMBEDDING);
ctx_server.queue_results.add_waiting_tasks(tasks);
ctx_server.queue_tasks.post(tasks);

@@ -3222,7 +3049,7 @@ int main(int argc, char ** argv) {
json responses = json::array();
bool error = false;
{
-   std::vector<server_task> tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_RERANK);
+   std::vector<server_task> tasks = ctx_server.create_tasks_inference({{"prompt", prompt}}, SERVER_TASK_INF_TYPE_RERANK);
ctx_server.queue_results.add_waiting_tasks(tasks);
ctx_server.queue_tasks.post(tasks);

examples/server/tests/features/infill.feature (new file, 36 lines)
@@ -0,0 +1,36 @@
+ @llama.cpp
+ @infill
+ Feature: llama.cpp server
+
+ # The current model is made by adding FIM tokens to the existing stories260K
+ # We may want to use a better model in the future, maybe something like SmolLM 360M
+
+ Background: Server startup
+   Given a server listening on localhost:8080
+   And a model file tinyllamas/stories260K-infill.gguf from HF repo ggml-org/models
+   And a model file test-model-infill.gguf
+   And a model alias tinyllama-infill
+   And 42 as server seed
+   And 1024 as batch size
+   And 1024 as ubatch size
+   And 2048 KV cache size
+   And 64 max tokens to predict
+   And 0.0 temperature
+   Then the server is starting
+   Then the server is healthy
+
+ Scenario: Infill without input_extra
+   Given a prompt "Complete this"
+   And an infill input extra none none
+   And an infill input prefix "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n int n_threads = llama_"
+   And an infill input suffix "}\n"
+   And an infill request with no api error
+   Then 64 tokens are predicted matching One|day|she|saw|big|scary|bird
+
+ Scenario: Infill with input_extra
+   Given a prompt "Complete this"
+   And an infill input extra "llama.h" "LLAMA_API int32_t llama_n_threads();\n"
+   And an infill input prefix "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n int n_threads = llama_"
+   And an infill input suffix "}\n"
+   And an infill request with no api error
+   Then 64 tokens are predicted matching cuts|Jimmy|mom|came|into|the|room"
@@ -80,6 +80,11 @@ def step_server_config(context, server_fqdn: str, server_port: str):
    context.lora_file = None
    context.disable_ctx_shift = False

+   # infill
+   context.infill_input_extra = None
+   context.infill_input_suffix = ''
+   context.infill_input_prefix = ''

    context.tasks_result = []
    context.concurrent_tasks = []
    context.prompts = []

@@ -291,6 +296,28 @@ async def step_request_completion(context, api_error: Literal['raised'] | str):
    assert completion == api_error_code, f"completion must be an {api_error_code} status code: {completion}"


+@step('an infill request with {api_error} api error')
+@async_run_until_complete
+async def step_request_completion(context, api_error: Literal['raised'] | str):
+    if api_error != 'no':
+        raise ValueError(f'api_error={api_error} is not yet implemented')
+    payload = {
+        "prompt": context.prompts[0],
+        "input_suffix": context.infill_input_suffix,
+        "input_prefix": context.infill_input_prefix,
+        "n_predict": context.n_predict,
+        "seed": context.seed,
+        "temperature": context.temperature,
+    }
+    if context.infill_input_extra is not None:
+        payload['input_extra'] = context.infill_input_extra
+    async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session:
+        async with session.post(f'{context.base_url}/infill',
+                                json=payload) as response:
+            assert response.status == 200
+            context.tasks_result = [await response.json()]


@step('{predicted_n:d} tokens are predicted matching {re_content}')
def step_n_tokens_predicted_with_content(context, predicted_n, re_content):
    context.completion = context.tasks_result.pop()

@@ -539,6 +566,25 @@ def step_a_prompt_prompt(context, prompt):
    context.n_prompts = len(context.prompts)


+# TODO: allow this to be repeated
+@step('an infill input extra {filename} {text}')
+def step_infill_input_extra(context, filename, text):
+    if filename == 'none':
+        context.infill_input_extra = None
+    else:
+        context.infill_input_extra = [{'filename': filename, 'text': text}]


+@step('an infill input suffix {text}')
+def step_infill_input_suffix(context, text):
+    context.infill_input_suffix = text


+@step('an infill input prefix {text}')
+def step_infill_input_prefix(context, text):
+    context.infill_input_prefix = text


@step('{num_prompts:d} prompts {prompt} with seed {seed:d}')
def step_many_prompts(context, num_prompts, prompt, seed):
    if context.seed is None:
@@ -226,7 +226,6 @@
    top_k: 40, // <= 0 to use vocab size
    top_p: 0.95, // 1.0 = disabled
    min_p: 0.05, // 0 = disabled
-   tfs_z: 1.0, // 1.0 = disabled
    typical_p: 1.0, // 1.0 = disabled
    presence_penalty: 0.0, // 0.0 = disabled
    frequency_penalty: 0.0, // 0.0 = disabled

@@ -788,7 +787,6 @@
    <details>
    <summary>More options</summary>
    <fieldset class="two">
-   ${FloatField({ label: "TFS-Z", max: 1.0, min: 0.0, name: "tfs_z", step: 0.01, value: params.value.tfs_z })}
    ${FloatField({ label: "Typical P", max: 1.0, min: 0.0, name: "typical_p", step: 0.01, value: params.value.typical_p })}
    ${FloatField({ label: "Presence penalty", max: 1.0, min: 0.0, name: "presence_penalty", step: 0.01, value: params.value.presence_penalty })}
    ${FloatField({ label: "Frequency penalty", max: 1.0, min: 0.0, name: "frequency_penalty", step: 0.01, value: params.value.frequency_penalty })}

@@ -229,7 +229,6 @@
    top_k: 40, // <= 0 to use vocab size
    top_p: 0.95, // 1.0 = disabled
    min_p: 0.05, // 0 = disabled
-   tfs_z: 1.0, // 1.0 = disabled
    typical_p: 1.0, // 1.0 = disabled
    presence_penalty: 0.0, // 0.0 = disabled
    frequency_penalty: 0.0, // 0.0 = disabled

@@ -791,7 +790,6 @@
    <details>
    <summary>More options</summary>
    <fieldset class="two">
-   ${FloatField({ label: "TFS-Z", max: 1.0, min: 0.0, name: "tfs_z", step: 0.01, value: params.value.tfs_z })}
    ${FloatField({ label: "Typical P", max: 1.0, min: 0.0, name: "typical_p", step: 0.01, value: params.value.typical_p })}
    ${FloatField({ label: "Presence penalty", max: 1.0, min: 0.0, name: "presence_penalty", step: 0.01, value: params.value.presence_penalty })}
    ${FloatField({ label: "Frequency penalty", max: 1.0, min: 0.0, name: "frequency_penalty", step: 0.01, value: params.value.frequency_penalty })}
@@ -24,6 +24,22 @@
#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613"

using json = nlohmann::ordered_json;
+using llama_tokens = std::vector<llama_token>;
+
+#define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
+#define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
+#define SLT_ERR(slot, fmt, ...) LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
+#define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
+
+#define SRV_INF(fmt, ...) LOG_INF("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
+#define SRV_WRN(fmt, ...) LOG_WRN("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
+#define SRV_ERR(fmt, ...) LOG_ERR("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
+#define SRV_DBG(fmt, ...) LOG_DBG("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
+
+#define QUE_INF(fmt, ...) LOG_INF("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
+#define QUE_WRN(fmt, ...) LOG_WRN("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
+#define QUE_ERR(fmt, ...) LOG_ERR("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
+#define QUE_DBG(fmt, ...) LOG_DBG("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)

// https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11
enum error_type {

@@ -52,9 +68,237 @@ static T json_value(const json & body, const std::string & key, const T & defaul
}

//
-// chat template utils
+// tokenizer and input processing utils
//

+static bool json_is_array_of_numbers(const json & data) {
+    if (data.is_array()) {
+        for (const auto & e : data) {
+            if (!e.is_number_integer()) {
+                return false;
+            }
+        }
+        return true;
+    }
+    return false;
+}
+
+// is array having BOTH numbers & strings?
+static bool json_is_array_of_mixed_numbers_strings(const json & data) {
+    bool seen_string = false;
+    bool seen_number = false;
+    if (data.is_array()) {
+        for (const auto & e : data) {
+            seen_string |= e.is_string();
+            seen_number |= e.is_number_integer();
+            if (seen_number && seen_string) {
+                return true;
+            }
+        }
+    }
+    return false;
+}
+
+/**
+ * this handles 2 cases:
+ * - only string, example: "string"
+ * - mixed string and tokens, example: [12, 34, "string", 56, 78]
+ */
+static llama_tokens tokenize_mixed(const llama_context * ctx, const json & json_prompt, bool add_special, bool parse_special) {
+    // If `add_bos` is true, we only add BOS, when json_prompt is a string,
+    // or the first element of the json_prompt array is a string.
+    llama_tokens prompt_tokens;
+
+    if (json_prompt.is_array()) {
+        bool first = true;
+        for (const auto & p : json_prompt) {
+            if (p.is_string()) {
+                auto s = p.template get<std::string>();
+
+                llama_tokens p;
+                if (first) {
+                    p = common_tokenize(ctx, s, add_special, parse_special);
+                    first = false;
+                } else {
+                    p = common_tokenize(ctx, s, false, parse_special);
+                }
+
+                prompt_tokens.insert(prompt_tokens.end(), p.begin(), p.end());
+            } else {
+                if (first) {
+                    first = false;
+                }
+
+                prompt_tokens.push_back(p.template get<llama_token>());
+            }
+        }
+    } else {
+        auto s = json_prompt.template get<std::string>();
+        prompt_tokens = common_tokenize(ctx, s, add_special, parse_special);
+    }
+
+    return prompt_tokens;
+}
+
+/**
+ * break the input "prompt" object into multiple prompt if needed, then tokenize them
+ * this supports these cases:
+ * - "prompt": "string"
+ * - "prompt": [12, 34, 56]
+ * - "prompt": [12, 34, "string", 56, 78]
+ * and multiple prompts (multi-tasks):
+ * - "prompt": ["string1", "string2"]
+ * - "prompt": ["string1", [12, 34, 56]]
+ * - "prompt": [[12, 34, "string", 56, 78], [12, 34, 56]]
+ */
+static std::vector<llama_tokens> tokenize_input_prompts(llama_context * ctx, const json & json_prompt, bool add_special, bool parse_special) {
+    std::vector<llama_tokens> result;
+    if (json_prompt.is_string() || json_is_array_of_mixed_numbers_strings(json_prompt)) {
+        // string or mixed
+        result.push_back(tokenize_mixed(ctx, json_prompt, add_special, parse_special));
+    } else if (json_is_array_of_numbers(json_prompt)) {
+        // array of tokens
+        result.push_back(json_prompt.get<llama_tokens>());
+    } else if (json_prompt.is_array()) {
+        // array of prompts
+        result.reserve(json_prompt.size());
+        for (const auto & p : json_prompt) {
+            if (p.is_string() || json_is_array_of_mixed_numbers_strings(p)) {
+                result.push_back(tokenize_mixed(ctx, p, add_special, parse_special));
+            } else if (json_is_array_of_numbers(p)) {
+                // array of tokens
+                result.push_back(p.get<llama_tokens>());
+            } else {
+                throw std::runtime_error("element of \"prompt\" must be a string, an list of tokens, or a list of mixed strings & tokens");
+            }
+        }
+    } else {
+        throw std::runtime_error("\"prompt\" must be a string, an list of tokens, a list of mixed strings & tokens, or a list of prompts");
+    }
+    return result;
+}
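Note (added for illustration, not part of the commit): a usage sketch for the new prompt-tokenization helper, assuming a valid llama_context and the helpers above are in scope; the prompt literal is made up.

#include <nlohmann/json.hpp>

// Illustrative only: one string prompt plus one raw token-list prompt yields two tokenized prompts.
static void example_tokenize_batch(llama_context * ctx) {
    const auto prompts = tokenize_input_prompts(ctx, nlohmann::ordered_json::parse(R"(["hello", [12, 34, 56]])"), true, true);
    // prompts.size() == 2; prompts[1] == {12, 34, 56} is passed through untouched
}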
|
//
|
||||||
|
// template utils
|
||||||
|
//
|
||||||
|
|
||||||
|
// format rerank task: [BOS]query[EOS][SEP]doc[EOS]
|
||||||
|
static llama_tokens format_rerank(const struct llama_model * model, const llama_tokens & query, const llama_tokens & doc) {
|
||||||
|
llama_tokens result;
|
||||||
|
result.reserve(doc.size() + query.size() + 4);
|
||||||
|
result.push_back(llama_token_bos(model));
|
||||||
|
result.insert(result.end(), query.begin(), query.end());
|
||||||
|
result.push_back(llama_token_eos(model));
|
||||||
|
result.push_back(llama_token_sep(model));
|
||||||
|
result.insert(result.end(), doc.begin(), doc.end());
|
||||||
|
result.push_back(llama_token_eos(model));
|
||||||
|
return result;
|
||||||
|
}
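
// Illustrative sketch (not part of the diff): for an assumed query = {q0, q1} and doc = {d0, d1, d2},
// format_rerank() produces the token sequence
//   [BOS] q0 q1 [EOS] [SEP] d0 d1 d2 [EOS]
// matching the layout named in the comment above.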

// format infill task
static llama_tokens format_infill(
        const llama_context * ctx,
        const json & input_prefix,
        const json & input_suffix,
        const json & input_extra,
        const int n_batch,
        const int n_predict,
        const int n_ctx,
        const bool spm_infill,
        const llama_tokens & tokens_prompt
    ) {
    // TODO: optimize this block by reducing memory allocations and movement

    // use FIM repo-level pattern:
    // ref: https://arxiv.org/pdf/2409.12186
    //
    // [FIM_REP]myproject
    // [FIM_SEP]filename0
    // extra chunk 0
    // [FIM_SEP]filename1
    // extra chunk 1
    // ...
    // [FIM_SEP]filename
    // [FIM_PRE]prefix[FIM_SUF]suffix[FIM_MID]prompt
    //
    llama_tokens extra_tokens;
    extra_tokens.reserve(n_ctx);

    auto model = llama_get_model(ctx);
    auto tokens_prefix = tokenize_mixed(ctx, input_prefix, false, false);
    auto tokens_suffix = tokenize_mixed(ctx, input_suffix, false, false);

    if (llama_token_fim_rep(model) != LLAMA_TOKEN_NULL) {
        // TODO: make project name an input
        static const auto k_fim_repo = common_tokenize(ctx, "myproject\n", false, false);

        extra_tokens.push_back(llama_token_fim_rep(model));
        extra_tokens.insert(extra_tokens.end(), k_fim_repo.begin(), k_fim_repo.end());
    }
    for (const auto & chunk : input_extra) {
        // { "text": string, "filename": string }
        const std::string text     = json_value(chunk, "text",     std::string());
        const std::string filename = json_value(chunk, "filename", std::string("tmp"));

        if (llama_token_fim_sep(model) != LLAMA_TOKEN_NULL) {
            const auto k_fim_file = common_tokenize(ctx, filename + "\n", false, false);

            extra_tokens.insert(extra_tokens.end(), llama_token_fim_sep(model));
            extra_tokens.insert(extra_tokens.end(), k_fim_file.begin(), k_fim_file.end());
        } else {
            // chunk separator in binary form to avoid confusing the AI
            static const char k_chunk_prefix_str[] = {0x0a, 0x0a, 0x2d, 0x2d, 0x2d, 0x20, 0x73, 0x6e, 0x69, 0x70, 0x70, 0x65, 0x74, 0x20, 0x2d, 0x2d, 0x2d, 0x0a, 0x0a, 0x00};
            static const auto k_chunk_prefix_tokens = common_tokenize(ctx, k_chunk_prefix_str, false, false);

            extra_tokens.insert(extra_tokens.end(), k_chunk_prefix_tokens.begin(), k_chunk_prefix_tokens.end());
        }

        const auto chunk_tokens = common_tokenize(ctx, text, false, false);
        extra_tokens.insert(extra_tokens.end(), chunk_tokens.begin(), chunk_tokens.end());
    }

    if (llama_token_fim_sep(model) != LLAMA_TOKEN_NULL) {
        // TODO: current filename
        static const auto k_fim_file = common_tokenize(ctx, "filename\n", false, false);

        extra_tokens.insert(extra_tokens.end(), llama_token_fim_sep(model));
        extra_tokens.insert(extra_tokens.end(), k_fim_file.begin(), k_fim_file.end());
    }

    // for now pick FIM context to fit in a batch (ratio prefix:suffix = 3:1, TODO: configurable?)
    const int n_prefix_take = std::min<int>(tokens_prefix.size(), 3*(n_batch/4));
    const int n_suffix_take = std::min<int>(tokens_suffix.size(), std::max<int>(0, (n_batch/4) - (2 + tokens_prompt.size())));

    SRV_DBG("n_prefix_take = %d, n_suffix_take = %d, total = %d\n", n_prefix_take, n_suffix_take, (n_prefix_take + n_suffix_take));

    // fill the rest of the context with extra chunks
    const int n_extra_take = std::min<int>(std::max<int>(0, n_ctx - (n_batch) - 2*n_predict), extra_tokens.size());

    tokens_prefix.erase(tokens_prefix.begin(), tokens_prefix.begin() + tokens_prefix.size() - n_prefix_take);
    tokens_suffix.resize(n_suffix_take);

    tokens_prefix.insert(tokens_prefix.begin(), llama_token_fim_pre(model));
    tokens_prefix.insert(tokens_prefix.end(),   tokens_prompt.begin(), tokens_prompt.end());
    tokens_suffix.insert(tokens_suffix.begin(), llama_token_fim_suf(model));

    auto embd_inp = spm_infill ? tokens_suffix : tokens_prefix;
    auto embd_end = spm_infill ? tokens_prefix : tokens_suffix;

    if (llama_add_bos_token(model)) {
        embd_inp.insert(embd_inp.begin(), llama_token_bos(model));
    }

    SRV_DBG("extra: n_ctx = %d, n_extra_take = %d, n_extra = %d\n", n_ctx, n_extra_take, (int) extra_tokens.size());

    // put the extra context before the FIM prefix
    embd_inp.insert(embd_inp.begin(), extra_tokens.end() - n_extra_take, extra_tokens.end());

    embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());
    embd_inp.push_back(llama_token_fim_mid(model));

    return embd_inp;
}
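
// Illustrative sketch (not part of the diff): with the assumed values n_batch = 2048 and
// tokens_prompt.size() = 10, the prefix/suffix budget computed above becomes
//   n_prefix_take = min(|prefix|, 3*(2048/4))                 = min(|prefix|, 1536)
//   n_suffix_take = min(|suffix|, max(0, (2048/4) - (2 + 10))) = min(|suffix|, 500)
// i.e. roughly a 3:1 prefix:suffix split, with two special tokens and the prompt reserved
// out of the suffix quarter.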

// Format given chat. If tmpl is empty, we take the template from model metadata
inline std::string format_chat(const struct llama_model * model, const std::string & tmpl, const std::vector<json> & messages) {
    std::vector<common_chat_msg> chat;

@ -229,18 +473,6 @@ static size_t find_partial_stop_string(const std::string &stop, const std::strin
    return std::string::npos;
}

-static bool json_is_array_of_numbers(const json & data) {
-    if (data.is_array()) {
-        for (const auto & e : data) {
-            if (!e.is_number()) {
-                return false;
-            }
-        }
-        return true;
-    }
-    return false;
-}
-
// TODO: reuse llama_detokenize
template <class Iter>
static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) {

@ -375,7 +607,7 @@ static json oaicompat_completion_params_parse(
    }

    // Copy remaining properties to llama_params
-    // This allows user to use llama.cpp-specific params like "mirostat", "tfs_z",... via OAI endpoint.
+    // This allows user to use llama.cpp-specific params like "mirostat", ... via OAI endpoint.
    // See "launch_slot_with_task()" for a complete list of params supported by llama.cpp
    for (const auto & item : body.items()) {
        // Exception: if "n_predict" is present, we overwrite the value specified earlier by "max_tokens"
@ -1484,14 +1484,19 @@ static void ggml_cuda_op_mul_mat(
                const size_t nbytes_data    = ggml_nbytes(src0);
                const size_t nbytes_padding = ggml_row_size(src0->type, MATRIX_ROW_PADDING - ne00 % MATRIX_ROW_PADDING);
                dev[id].src0_dd = dev[id].src0_dd_alloc.alloc(ctx.pool(id), nbytes_data + nbytes_padding);
+                // TODO: remove this for MUSA once the Guilty Lockup issue is resolved
+#ifndef GGML_USE_MUSA
                CUDA_CHECK(cudaMemsetAsync(dev[id].src0_dd, 0, nbytes_data + nbytes_padding, stream));
+#else // GGML_USE_MUSA
+                CUDA_CHECK(cudaMemsetAsync(dev[id].src0_dd + nbytes_data, 0, nbytes_padding, stream));
+#endif // !GGML_USE_MUSA
            }

            // If src0 is on a temporary compute buffer (partial offloading) there may be some padding that needs to be cleared:
            if (ne00 % MATRIX_ROW_PADDING != 0 && ggml_is_quantized(src0->type) && ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE && src0->view_src == nullptr) {
                const size_t nbytes_data    = ggml_row_size(src0->type, (dev[id].row_high - dev[id].row_low)*ne00);
                const size_t nbytes_padding = ggml_row_size(src0->type, MATRIX_ROW_PADDING - ne00 % MATRIX_ROW_PADDING);
-                CUDA_CHECK(cudaMemsetAsync(dev[id].src0_dd + nbytes_data , 0, nbytes_padding, stream));
+                CUDA_CHECK(cudaMemsetAsync(dev[id].src0_dd + nbytes_data, 0, nbytes_padding, stream));
            }

            if (src1_on_device && src1_is_contiguous) {
@ -1,6 +1,6 @@
#include "common.cuh"

-#define CUDA_CPY_BLOCK_SIZE 32
+#define CUDA_CPY_BLOCK_SIZE 64

void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1);
@ -1015,19 +1015,21 @@ static void ggml_metal_encode_node(
    id<MTLBuffer> id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil;
    id<MTLBuffer> id_dst  = dst  ? ggml_metal_get_buffer(dst,  &offs_dst)  : nil;

-    //GGML_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
-    //if (src0) {
-    //  GGML_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02,
-    //      ggml_is_contiguous(src0), src0->name);
-    //}
-    //if (src1) {
-    //  GGML_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12,
-    //      ggml_is_contiguous(src1), src1->name);
-    //}
-    //if (dst) {
-    //  GGML_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2,
-    //      dst->name);
-    //}
+#if 0
+    GGML_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
+    if (src0) {
+        GGML_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03,
+                ggml_is_contiguous(src0), src0->name);
+    }
+    if (src1) {
+        GGML_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13,
+                ggml_is_contiguous(src1), src1->name);
+    }
+    if (dst) {
+        GGML_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3,
+                dst->name);
+    }
+#endif

    id<MTLDevice> device = ctx_dev->mtl_device;

@ -1810,14 +1812,16 @@ static void ggml_metal_encode_node(
                [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
                [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
-               [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
-               [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8];
-               [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9];
-               [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
-               [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11];
-               [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
-               [encoder setBytes:&r2 length:sizeof(r2) atIndex:13];
-               [encoder setBytes:&r3 length:sizeof(r3) atIndex:14];
+               [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:7];
+               [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:8];
+               [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:9];
+               [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:10];
+               [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:11];
+               [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:12];
+               [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
+               [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
+               [encoder setBytes:&r2 length:sizeof(r2) atIndex:15];
+               [encoder setBytes:&r3 length:sizeof(r3) atIndex:16];
                [encoder setThreadgroupMemoryLength:8192 atIndex:0];
                [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
            } else {

@ -1986,20 +1990,22 @@ static void ggml_metal_encode_node(
                [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
                [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
-               [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9];
-               [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10];
-               [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:11];
-               [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:12];
-               [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:13];
-               [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
-               [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
-               [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
-               [encoder setBytes:&r2 length:sizeof(r2) atIndex:17];
-               [encoder setBytes:&r3 length:sizeof(r3) atIndex:18];
+               [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
+               [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
+               [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
+               [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
+               [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:13];
+               [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:14];
+               [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:15];
+               [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:16];
+               [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
+               [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:18];
+               [encoder setBytes:&r2 length:sizeof(r2) atIndex:19];
+               [encoder setBytes:&r3 length:sizeof(r3) atIndex:20];

                if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||
                    src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
                    src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                }
                else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {

@ -2048,6 +2054,9 @@ static void ggml_metal_encode_node(

                GGML_ASSERT(src1t == GGML_TYPE_F32);

+               GGML_ASSERT(ne03 == 1);
+               GGML_ASSERT(ne13 == 1);
+
                // find the break-even point where the matrix-matrix kernel becomes more efficient compared
                // to the matrix-vector kernel
                // ne20 = n_used_experts
File diff suppressed because it is too large
@ -213,6 +213,7 @@ struct vk_device_struct {
    vk_pipeline pipeline_sum_rows_f32;
    vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
    vk_pipeline pipeline_timestep_embedding_f32;
+    vk_pipeline pipeline_pool2d_f32;

    std::unordered_map<std::string, vk_pipeline_ref> pipelines;
    std::unordered_map<std::string, uint64_t> pipeline_descriptor_set_requirements;

@ -403,6 +404,17 @@ struct vk_op_timestep_embedding_push_constants {
    uint32_t max_period;
};

+struct vk_op_pool2d_push_constants {
+    uint32_t IW; uint32_t IH;
+    uint32_t OW; uint32_t OH;
+    uint32_t OC;
+    uint32_t pelements;
+    uint32_t op;
+    int32_t k0; int32_t k1;
+    int32_t s0; int32_t s1;
+    int32_t p0; int32_t p1;
+};
+
// Allow pre-recording command buffers
struct vk_staging_memcpy {
    vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}

@ -1803,6 +1815,8 @@ static void ggml_vk_load_shaders(vk_device& device) {

    ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1);

+    ggml_vk_create_pipeline(device, device->pipeline_pool2d_f32, "pool2d_f32", pool2d_f32_len, pool2d_f32_data, "main", 2, sizeof(vk_op_pool2d_push_constants), {512, 1, 1}, {}, 1);
+
    for (auto &c : compiles) {
        c.wait();
    }

@ -4234,6 +4248,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
            return ctx->device->pipeline_timestep_embedding_f32;
        }
        return nullptr;
+    case GGML_OP_POOL_2D:
+        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_pool2d_f32;
+        }
+        return nullptr;
    case GGML_OP_LEAKY_RELU:
        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
            return ctx->device->pipeline_leaky_relu_f32;

@ -4464,6 +4483,14 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
            uint32_t half_ceil = (dim + 1) / 2;
            elements = { half_ceil, (uint32_t)src0->ne[0], 1 };
        } break;
+    case GGML_OP_POOL_2D:
+        {
+            const uint32_t N = dst->ne[3];
+            const uint32_t OC = dst->ne[2];
+            const uint32_t OH = dst->ne[1];
+            const uint32_t OW = dst->ne[0];
+            elements = { N * OC * OH * OW, 1, 1};
+        } break;
    case GGML_OP_ADD:
    case GGML_OP_DIV:
    case GGML_OP_MUL:

@ -4914,6 +4941,34 @@ static void ggml_vk_timestep_embedding(ggml_backend_vk_context * ctx, vk_context
    }, dryrun);
}

+static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
+    uint32_t op = static_cast<uint32_t>(dst->op_params[0]);
+    const int32_t k1 = dst->op_params[1];
+    const int32_t k0 = dst->op_params[2];
+    const int32_t s1 = dst->op_params[3];
+    const int32_t s0 = dst->op_params[4];
+    const int32_t p1 = dst->op_params[5];
+    const int32_t p0 = dst->op_params[6];
+
+    const uint32_t IH = src0->ne[1];
+    const uint32_t IW = src0->ne[0];
+
+    const uint32_t N = dst->ne[3];
+
+    const uint32_t OC = dst->ne[2];
+    const uint32_t OH = dst->ne[1];
+    const uint32_t OW = dst->ne[0];
+
+    const uint32_t parallel_elements = N * OC * OH * OW;
+
+    ggml_vk_op_f32<vk_op_pool2d_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_POOL_2D, {
+        IW, IH, OW, OH, OC,
+        parallel_elements,
+        op,
+        k0, k1, s0, s1, p0, p1,
+    }, dryrun);
+}
+
static void ggml_vk_leaky_relu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
    const float * op_params = (const float *)dst->op_params;
    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f }, dryrun);

@ -5792,6 +5847,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
    case GGML_OP_SUM_ROWS:
    case GGML_OP_IM2COL:
    case GGML_OP_TIMESTEP_EMBEDDING:
+    case GGML_OP_POOL_2D:
    case GGML_OP_LEAKY_RELU:
        break;
    default:

@ -5927,6 +5983,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
    case GGML_OP_TIMESTEP_EMBEDDING:
        ggml_vk_timestep_embedding(ctx, compute_ctx, src0, node, dryrun);

+        break;
+    case GGML_OP_POOL_2D:
+        ggml_vk_pool_2d(ctx, compute_ctx, src0, node, dryrun);
+
        break;
    case GGML_OP_LEAKY_RELU:
        ggml_vk_leaky_relu(ctx, compute_ctx, src0, node, dryrun);

@ -6018,6 +6078,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
    case GGML_OP_SUM_ROWS:
    case GGML_OP_IM2COL:
    case GGML_OP_TIMESTEP_EMBEDDING:
+    case GGML_OP_POOL_2D:
    case GGML_OP_LEAKY_RELU:
    case GGML_OP_REPEAT:
        buf = tensor->buffer;

@ -6821,6 +6882,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
        case GGML_OP_SUM_ROWS:
        case GGML_OP_IM2COL:
        case GGML_OP_TIMESTEP_EMBEDDING:
+        case GGML_OP_POOL_2D:
        case GGML_OP_LEAKY_RELU:
            return true;
        default:

@ -7334,6 +7396,16 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
        const int32_t dim = tensor->op_params[0];
        const int32_t max_period = tensor->op_params[1];
        tensor_clone = ggml_timestep_embedding(ggml_ctx, src0_clone, dim, max_period);
+    } else if (tensor->op == GGML_OP_POOL_2D) {
+        enum ggml_op_pool op = static_cast<ggml_op_pool>(dst->op_params[0]);
+        const int32_t k0 = tensor->op_params[1];
+        const int32_t k1 = tensor->op_params[2];
+        const int32_t s0 = tensor->op_params[3];
+        const int32_t s1 = tensor->op_params[4];
+        const int32_t p0 = tensor->op_params[5];
+        const int32_t p1 = tensor->op_params[6];
+
+        tensor_clone = ggml_pool_2d(ggml_ctx, src0_clone, op, k0, k1, s0, s1, p0, p1);
    } else if (tensor->op == GGML_OP_LEAKY_RELU) {
        const float * op_params = (const float *)tensor->op_params;
        tensor_clone = ggml_leaky_relu(ggml_ctx, src0_clone, op_params[0], false);
@ -942,6 +942,36 @@ class tinyBLAS_Q0_AVX {
        return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)), _mm_set1_epi8(8));
    }

+    inline __m256i load(const block_q5_0 *b) {
+        return _mm256_or_si256(denibble(b->qs), bittobyte(b->qh));
+    }
+
+    inline __m128i load0(const block_q5_0* b) {
+        const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
+        uint32_t x32;
+        memcpy(&x32, b->qh, sizeof(uint32_t));
+        __m128i qxl = _mm_and_si128(_mm_set1_epi8(15), x);
+        __m128i bytesl = _mm_cmpeq_epi8(_mm_set1_epi64x(-1),
+                                        _mm_or_si128(_mm_set1_epi64x(0x7fbfdfeff7fbfdfe),
+                                                     _mm_shuffle_epi8(_mm_set1_epi32(x32),
+                                                                      _mm_set_epi64x(0x0101010101010101, 0x0000000000000000))));
+        bytesl = _mm_andnot_si128(bytesl, _mm_set1_epi8((char)0xF0));
+        return _mm_or_si128(qxl, bytesl);
+    }
+
+    inline __m128i load1(const block_q5_0* b) {
+        const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
+        uint32_t x32;
+        memcpy(&x32, b->qh, sizeof(uint32_t));
+        __m128i qxh = _mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4));
+        __m128i bytesh = _mm_cmpeq_epi8(_mm_set1_epi64x(-1),
+                                        _mm_or_si128(_mm_set1_epi64x(0x7fbfdfeff7fbfdfe),
+                                                     _mm_shuffle_epi8(_mm_set1_epi32(x32),
+                                                                      _mm_set_epi64x(0x0303030303030303, 0x0202020202020202))));
+        bytesh = _mm_andnot_si128(bytesh, _mm_set1_epi8((char)0xF0));
+        return _mm_or_si128(qxh, bytesh);
+    }
+
    inline __m256i load(const block_iq4_nl *b) {
        return MM256_SET_M128I(load1(b), load0(b));
    }

@ -973,6 +1003,17 @@ class tinyBLAS_Q0_AVX {
                                 _mm_srli_epi16(x, 4), 1));
    }

+    static inline __m256i bittobyte(const uint8_t *p) {
+        uint32_t x32;
+        memcpy(&x32, p, sizeof(uint32_t));
+        __m256i bytes = _mm256_cmpeq_epi8(_mm256_set1_epi64x(-1),
+                                          _mm256_or_si256(_mm256_set1_epi64x(0x7fbfdfeff7fbfdfe),
+                                                          _mm256_shuffle_epi8(_mm256_set1_epi32(x32),
+                                                                              _mm256_set_epi64x(0x0303030303030303, 0x0202020202020202,
+                                                                                                0x0101010101010101, 0x0000000000000000))));
+        return _mm256_andnot_si256(bytes, _mm256_set1_epi8((char)0xF0));
+    }
+
    const TA *const A;
    const TB *const B;
    TC *const C;

@ -1182,6 +1223,22 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
#endif
    }

+    case GGML_TYPE_Q5_0: {
+        if (Btype != GGML_TYPE_Q8_0)
+            return false;
+#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
+        tinyBLAS_Q0_AVX<block_q5_0, block_q8_0, float> tb{
+            k, (const block_q5_0 *)A, lda,
+            (const block_q8_0 *)B, ldb,
+            (float *)C, ldc,
+            ith, nth};
+        tb.matmul(m, n);
+        return true;
+#else
+        return false;
+#endif
+    }
+
    case GGML_TYPE_IQ4_NL: {
        if (Btype != GGML_TYPE_Q8_0)
            return false;
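
// Illustrative sketch (not part of the diff): scalar equivalent of the value that load0()/load1()
// assemble for one Q5_0 element, assuming the usual block layout (low nibbles in qs, 32 high bits
// in qh). ORing the nibble with 0xF0 whenever the high bit is clear gives the same signed result
// as the reference dequantization ((nibble | high_bit << 4) - 16).
static inline int8_t q5_0_signed_value(uint8_t nibble, int high_bit) {
    return high_bit ? (int8_t) nibble           // high bit set:   0 .. 15
                    : (int8_t)(nibble | 0xF0);  // high bit clear: -16 .. -1
}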
ggml/src/vulkan-shaders/pool2d.comp (new file, 74 lines)

@ -0,0 +1,74 @@
#version 450

#include "types.comp"

#extension GL_EXT_shader_16bit_storage : require

layout(push_constant) uniform parameter {
    uint IW; uint IH;
    uint OW; uint OH;
    uint OC;
    uint pelements;
    uint op;
    int k0; int k1;
    int s0; int s1;
    int p0; int p1;
} p;

#define BLOCK_SIZE 512
#define FLT_MAX 3.402823466e+38F
#define OP_POOL_MAX 0u
#define OP_POOL_AVG 1u

layout (local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;

layout(binding = 0) readonly buffer X {A_TYPE data_a[];};
layout(binding = 1) writeonly buffer D {D_TYPE data_d[];};

void main() {
    const uint idx = gl_GlobalInvocationID.x;
    if (idx >= p.pelements) {
        return;
    }

    const uint O_HW = p.OW * p.OH;

    const uint nc = idx / O_HW;
    const uint cur_oh = (idx % O_HW) / p.OW;
    const uint cur_ow = (idx % O_HW) % p.OW;

    const int start_h = int(cur_oh) * p.s0 - p.p0;
    const uint bh = max(start_h, 0);
    const uint eh = min(start_h + p.k0, p.IH);

    const int start_w = int(cur_ow) * p.s1 - p.p1;
    const uint bw = max(start_w, 0);
    const uint ew = min(start_w + p.k1, p.IW);

    const float scale = 1.0 / float(p.k0 * p.k1);
    float res;

    if (p.op == OP_POOL_AVG) {
        res = 0.0;
    } else if (p.op == OP_POOL_MAX) {
        res = -FLT_MAX;
    } else {
        return;
    }

    #pragma unroll
    for (uint i = bh; i < eh; i++) {
        #pragma unroll
        for (uint j = bw; j < ew; j++) {
            const float cur = D_TYPE(data_a[nc * p.IH * p.IW + i * p.IW + j]);

            if (p.op == OP_POOL_AVG) {
                res += cur * scale;
            } else if (p.op == OP_POOL_MAX) {
                res = max(res, cur);
            }
        }
    }

    data_d[nc * O_HW + cur_oh * p.OW + cur_ow] = res;
}
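
// Illustrative sketch (not part of the diff): the same flat-index decomposition the shader's
// main() performs, written as host-side C++. The concrete numbers are assumptions for the example.
#include <cstdint>
#include <cstdio>

static void example_pool2d_index(uint32_t idx, uint32_t OW, uint32_t OH) {
    const uint32_t O_HW   = OW * OH;           // one output plane
    const uint32_t nc     = idx / O_HW;        // which (batch, channel) slice
    const uint32_t cur_oh = (idx % O_HW) / OW; // output row
    const uint32_t cur_ow = (idx % O_HW) % OW; // output column
    std::printf("idx=%u -> nc=%u oh=%u ow=%u\n", idx, nc, cur_oh, cur_ow);
}
// example_pool2d_index(2757, 32, 32) prints: idx=2757 -> nc=2 oh=22 ow=5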

@ -494,6 +494,10 @@ void process_shaders(std::vector<std::future<void>>& tasks) {
    tasks.push_back(std::async(std::launch::async, [=] {
        string_to_spv("timestep_embedding_f32", "timestep_embedding.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
    }));

+    tasks.push_back(std::async(std::launch::async, [=] {
+        string_to_spv("pool2d_f32", "pool2d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
+    }));
}

void write_output_files() {
@ -1089,9 +1089,6 @@ extern "C" {
    /// @details Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841
    LLAMA_API struct llama_sampler * llama_sampler_init_min_p (float p, size_t min_keep);

-    /// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
-    LLAMA_API struct llama_sampler * llama_sampler_init_tail_free (float z, size_t min_keep);
-
    /// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
    LLAMA_API struct llama_sampler * llama_sampler_init_typical (float p, size_t min_keep);

@ -1143,6 +1140,16 @@ extern "C" {
        bool penalize_nl, // consider newlines as a repeatable token
        bool ignore_eos); // ignore the end-of-sequence token

+    /// @details DRY sampler, designed by p-e-w, as described in: https://github.com/oobabooga/text-generation-webui/pull/5677, porting Koboldcpp implementation authored by pi6am: https://github.com/LostRuins/koboldcpp/pull/982
+    LLAMA_API struct llama_sampler * llama_sampler_init_dry(
+        const struct llama_model * model,
+        float dry_multiplier,
+        float dry_base,
+        int32_t dry_allowed_length,
+        int32_t dry_penalty_last_n,
+        const char ** seq_breakers,
+        size_t num_breakers);
+
    LLAMA_API struct llama_sampler * llama_sampler_init_logit_bias(
        int32_t n_vocab,
        int32_t n_logit_bias,
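
// Illustrative usage sketch (not part of the diff): wiring the new DRY sampler into an existing
// sampler chain. The breaker strings and numeric values below are placeholder choices, not
// project defaults; `model` and `chain` are assumed to already exist.
static void example_add_dry_sampler(const struct llama_model * model, struct llama_sampler * chain) {
    const char * breakers[] = { "\n", ":", "\"", "*" };

    struct llama_sampler * dry = llama_sampler_init_dry(
        model,
        /* dry_multiplier     = */ 0.8f,
        /* dry_base           = */ 1.75f,
        /* dry_allowed_length = */ 2,
        /* dry_penalty_last_n = */ -1,   // -1 = penalize over the whole context
        breakers,
        /* num_breakers       = */ 4);

    llama_sampler_chain_add(chain, dry);
}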

@ -113,7 +113,7 @@ static void llama_sampler_softmax_impl(llama_token_data_array * cur_p) {
}

static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k) {
-    // TODO: move bucket sort to separate function so that top_p/tail_free/typical/softmax first is equally fast
+    // TODO: move bucket sort to separate function so that top_p/typical/softmax first is equally fast
    // if (k >= (int32_t)cur_p->size) {
    //     return;
    // }
@ -733,101 +733,6 @@ struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) {
    };
}

-// tail-free
-
-struct llama_sampler_tail_free {
-    const float  z;
-    const size_t min_keep;
-};
-
-static const char * llama_sampler_tail_free_name(const struct llama_sampler * /*smpl*/) {
-    return "tail-free";
-}
-
-static void llama_sampler_tail_free_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
-    const auto * ctx = (llama_sampler_tail_free *) smpl->ctx;
-
-    if (ctx->z >= 1.0f || cur_p->size <= 2) {
-        return;
-    }
-
-    llama_sampler_softmax_impl(cur_p);
-
-    // Compute the first and second derivatives
-    std::vector<float> first_derivatives(cur_p->size - 1);
-    std::vector<float> second_derivatives(cur_p->size - 2);
-
-    for (size_t i = 0; i < first_derivatives.size(); ++i) {
-        first_derivatives[i] = cur_p->data[i].p - cur_p->data[i + 1].p;
-    }
-    for (size_t i = 0; i < second_derivatives.size(); ++i) {
-        second_derivatives[i] = first_derivatives[i] - first_derivatives[i + 1];
-    }
-
-    // Calculate absolute value of second derivatives
-    for (size_t i = 0; i < second_derivatives.size(); ++i) {
-        second_derivatives[i] = std::abs(second_derivatives[i]);
-    }
-
-    // Normalize the second derivatives
-    {
-        const float second_derivatives_sum = std::accumulate(second_derivatives.begin(), second_derivatives.end(), 0.0f);
-
-        if (second_derivatives_sum > 1e-6f) {
-            for (float & value : second_derivatives) {
-                value /= second_derivatives_sum;
-            }
-        } else {
-            for (float & value : second_derivatives) {
-                value = 1.0f / second_derivatives.size();
-            }
-        }
-    }
-
-    float cum_sum = 0.0f;
-    size_t last_idx = cur_p->size;
-    for (size_t i = 0; i < second_derivatives.size(); ++i) {
-        cum_sum += second_derivatives[i];
-
-        // Check if the running sum is greater than z or if we have kept at least min_keep tokens
-        if (cum_sum > ctx->z && i >= ctx->min_keep) {
-            last_idx = i;
-            break;
-        }
-    }
-
-    // Resize the output vector to keep only the tokens above the tail location
-    cur_p->size = last_idx;
-}
-
-static struct llama_sampler * llama_sampler_tail_free_clone(const struct llama_sampler * smpl) {
-    const auto * ctx = (const llama_sampler_tail_free *) smpl->ctx;
-    return llama_sampler_init_tail_free(ctx->z, ctx->min_keep);
-}
-
-static void llama_sampler_tail_free_free(struct llama_sampler * smpl) {
-    delete (llama_sampler_tail_free *) smpl->ctx;
-}
-
-static struct llama_sampler_i llama_sampler_tail_free_i = {
-    /* .name   = */ llama_sampler_tail_free_name,
-    /* .accept = */ nullptr,
-    /* .apply  = */ llama_sampler_tail_free_apply,
-    /* .reset  = */ nullptr,
-    /* .clone  = */ llama_sampler_tail_free_clone,
-    /* .free   = */ llama_sampler_tail_free_free,
-};
-
-struct llama_sampler * llama_sampler_init_tail_free(float z, size_t min_keep) {
-    return new llama_sampler {
-        /* .iface = */ &llama_sampler_tail_free_i,
-        /* .ctx   = */ new llama_sampler_tail_free {
-            /* .z        = */ z,
-            /*. min_keep = */ min_keep,
-        },
-    };
-}
-
// typical

struct llama_sampler_typical {

@ -1683,6 +1588,397 @@ struct llama_sampler * llama_sampler_init_penalties(
    };
}

// DRY

struct llama_sampler_dry {
    int32_t total_context_size;

    const float dry_multiplier;
    const float dry_base;
    const int32_t dry_allowed_length;
    const int32_t dry_penalty_last_n;

    std::unordered_multimap<llama_token, std::vector<llama_token>> dry_processed_breakers;
    std::vector<int> dry_repeat_count;
    std::unordered_map<llama_token, int> dry_max_token_repeat;
    ring_buffer<llama_token> last_tokens;
};

// Ported from Koboldcpp, original PR: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am)
static void get_overlapping_token_sequences(const llama_vocab & vocab, const std::string& str, std::unordered_multimap<llama_token, std::vector<llama_token>>& token_sequences, int max_tail_len = -1) {
    for (llama_token token_id = 0; token_id < (llama_token)vocab.n_vocab; token_id++) {
        std::string word = llama_detokenize(vocab, {token_id}, true);
        if (word.find(str) != std::string::npos) {
            token_sequences.emplace(token_id, std::vector<llama_token>());
        } else {
            size_t word_len = word.size(), str_len = str.size();
            size_t pos = -1;
            while ((pos = word.find(str[0], pos + 1)) != std::string::npos) {
                bool match = true;
                size_t i;
                for (i = 1; i < str_len && i + pos < word_len; ++i) {
                    if (word[pos + i] != str[i]) {
                        match = false;
                        break;
                    }
                }
                if (match) {
                    std::vector<llama_token> tokenization = llama_tokenize_internal(vocab, str.substr(i), false, false);
                    if (max_tail_len >= 0 && tokenization.size() > (size_t)max_tail_len) {
                        tokenization.resize(max_tail_len);
                    }

                    // Ensure we don't already have a duplicate matching tokenization
                    auto its = token_sequences.equal_range(token_id);
                    bool found = false;
                    for (auto it = its.first; it != its.second; ++it) {
                        if (tokenization == it->second) {
                            found = true;
                            break;
                        }
                    }
                    if (!found) {
                        token_sequences.emplace(token_id, tokenization);
                    }
                }
            }
        }
    }
}
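
// Illustrative sketch (not part of the diff): for an assumed breaker string "User:", the map
// built above would contain
//   - every token whose text contains "User:" outright        -> head token with an empty tail, and
//   - every token whose text ends with a prefix of "User:"    -> head token plus the tokenization
//     (e.g. a token ending in "Us")                               of the remaining characters ("er:")
// so a restart sequence can still be recognized when the breaker is split across token boundaries.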

static const char * llama_sampler_dry_name(const struct llama_sampler * /*smpl*/) {
    return "dry";
}

static void llama_sampler_dry_accept(struct llama_sampler * smpl, llama_token token) {
    auto * ctx = (llama_sampler_dry *) smpl->ctx;
    if (ctx->dry_multiplier == 0.0f || ctx->dry_base < 1.0f || ctx->dry_penalty_last_n == 0) {
        return;
    }

    ctx->last_tokens.push_back(token);
}

// Ported from Koboldcpp, original PR: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am)
static void llama_sampler_dry_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
    auto * ctx = (llama_sampler_dry *) smpl->ctx;

    if (ctx->dry_multiplier == 0.0f || ctx->dry_base < 1.0f || ctx->dry_penalty_last_n == 0) {
        return;
    }

    int32_t effective_dry_penalty_last_n = (ctx->dry_penalty_last_n == -1) ? ctx->total_context_size : std::max(ctx->dry_penalty_last_n, 0);
    int last_n_repeat = std::min(std::min((int)ctx->last_tokens.size(), effective_dry_penalty_last_n), ctx->total_context_size);

    if (last_n_repeat <= ctx->dry_allowed_length) {
        return;
    }

    ctx->dry_repeat_count.assign(last_n_repeat, 0);
    ctx->dry_max_token_repeat.clear();

    // Step 1: Look for restart sequences to limit the maximum repetition length.
    // Work backwards through the context looking for any token that begins a restart sequence.
    //
    // The collection `restart_sequences` is a mapping from a "head" token to all "tail"
    // sequences that together comprise a restart sequence. This allows us to quickly check
    // whether each token is the head of a complete sequence. Most restart sequences are actually
    // a single token, and for these the "tail" is an empty vector.
    //
    // If the token is a "head", test all restart sequences that begin with this token
    // (there will often only be one sequence for each token, but if sequences like 'aaaq1' and
    // 'aaa1' are used as restart strings, both could start with 'aaa' when tokenized). The
    // longest matching sequence (if any) is used to limit the maximum repetition length.
    //
    // Note that in the case of a short sequence contained in a longer one, this might fail to
    // find the smallest value for `rep_limit`. For example, if 'amniotic' and 'ni' are both used as
    // restart sequences, 'ni' will be found first, and since it's shorter it will fail to suppress
    // 'otic'. This is a minor issue since fully contained restart sequences are likely to be rare.
    //
    // This is theoretically worst-case O(N^2) for arbitrary restart sequences, which is why we
    // have already clamped the maximum tail sequence length when generating `restart_sequences`.
    // With clamping, this scan is O(N) in the context length.

    int rep_limit = last_n_repeat;
    for (int i = 0; i < last_n_repeat; ++i) {
        llama_token token = ctx->last_tokens.rat(i);
        auto its = ctx->dry_processed_breakers.equal_range(token);
        if (its.first == ctx->dry_processed_breakers.end()) {
            continue;
        }
        int longest_match = -1;
        for (auto it = its.first; it != its.second; ++it) {
            // Note that (*it) does not contain the head character, so seq_len will be
            // the restart sequence length minus 1.
            // In the common case of a single-token restart sequence, (*it) will be empty
            // and we will trivially match.
            int seq_len = (int)it->second.size();
            if (seq_len > longest_match && seq_len <= (int)i) {
                bool match = true;
                for (int offset = 0; offset < seq_len; ++offset) {
                    // The -1 when indexing `last_tokens` is because we already matched the head.
                    if (it->second[offset] != ctx->last_tokens.rat(i - offset - 1)) {
                        match = false;
                        break;
                    }
                }
                if (match) {
                    longest_match = seq_len;
                }
            }
        }
        if (longest_match >= 0) {
            // We found a restart sequence starting `i` tokens from the end and continuing for
            // `longest_match` tokens.
            rep_limit = i - longest_match;
            break;
        }
    }
    if (rep_limit < ctx->dry_allowed_length) {
        return;
    }

    // Step 2: Iterate in reverse over the last N tokens of the context, using the "Z-algorithm" (in
    // the reverse direction) to efficiently compute the positions and lengths of suffixes appearing
    // elsewhere in the context. We limit the suffix length to `rep_limit` to respect restart sequences.
    //
    // This algorithm is not currently documented on Wikipedia, but there is a clear description here:
    // https://ivanyu.me/blog/2014/10/15/z-algorithm/
    //
    // The code below is adapted from the public domain implementation by the same author here:
    // https://github.com/ivanyu/string-algorithms/blob/master/z_algorithm.py
    //
    // Example:
    // Last N tokens: a b c c b c y a b c
    // Repeat counts: 0 0 3 1 0 2 0 0 0 0
    //                    ^
    //   This `3` means that the last three tokens of the context (a b c) also appear here.
    //
    // This step is worst case O(N) since the Z-algorithm is linear, despite the appearance of nested
    // for/while loops. This can be seen by observing that the `lt` and `rt` bounds are set after each
    // repeated suffix is detected (i.e. after each while loop when n > 0). These bound variables
    // ensure that the inner while loops only examine each token in the context once as the outer
    // for loop iterates over the context.

    {
        const int last = last_n_repeat - 1;
        int rt = 0, lt = 0;

        for (int k = 1; k < last_n_repeat; ++k) {
            if (k > rt) {
                // If k is outside the current Z-box, do naive computation.
                int n = 0;
                while (n + k < last_n_repeat && ctx->last_tokens.rat(n) == ctx->last_tokens.rat(n+k)) {
                    ++n;
                }
                ctx->dry_repeat_count[last - k] = std::min(n, rep_limit);
                if (n > 0) {
                    lt = k;
                    rt = k+n-1;
                }
            } else {
                // If k is inside the current Z-box, consider two cases.

                int p = k - lt; // Pair index.
                int right_part_len = rt - k + 1;

                if (ctx->dry_repeat_count[last - p] < right_part_len) {
                    int n = std::min(ctx->dry_repeat_count[last - p], rep_limit);
                    ctx->dry_repeat_count[last - k] = n;
                } else {
                    int i = rt + 1;
                    while (i < last_n_repeat && ctx->last_tokens.rat(i) == ctx->last_tokens.rat(i - k)) {
                        i += 1;
                    }

                    int n = std::min(i - k, rep_limit);
                    ctx->dry_repeat_count[last - k] = n;
                    lt = k;
                    rt = i - 1;
                }
            }
        }
    }

    // Step 3: Iterate over dry_repeat_count and last_tokens, examining the maximum repeat length
    // that would be generated by emitting each new token that would extend a sequence.
    //
    // Following the same example as above:
    // Last N tokens: a b c c b c y a b c
    // Repeat counts: 0 0 3 1 0 2 0 0 0 0
    //
    // For each non-zero, look ahead one token. This token, if emitted, would extend the repetition.
    // c: 3 -> 4 (from `a b c` to `a b c c`)
    // b: 1 -> 2 (from `c` to `c b`)
    // y: 2 -> 3 (from `b c` to `b c y`)

    for (int i = 0; i < last_n_repeat - 1; ++i) {
        int repeat_len = ctx->dry_repeat_count[i];
        if (repeat_len >= ctx->dry_allowed_length) {
            // This token ends a repeat, so the next token would continue one.
            // By convention, the value of `repeat_len` only includes the tokens currently
            // in the context, not the new token that would be added.
            llama_token token = ctx->last_tokens.rat(last_n_repeat - 2 - i);
            // Track the maximum sequence ending in this token.
            const auto& it = ctx->dry_max_token_repeat.find(token);
            if (it == ctx->dry_max_token_repeat.end() || it->second < repeat_len) {
                ctx->dry_max_token_repeat[token] = repeat_len;
            }
        }
    }

    // Step 4: Apply logit penalties based on the maximum repeat length for relevant tokens.

    // Prevent floating point overflow in `pow(penalty_base, exponent)` by clamping to `max_exponent`.
    // Compute it from `penalty_base` and the approximate log of `std::numeric_limits<float>::max()`
    const float FLOAT_MAX_LOG = 88.7228391f;
    int max_exponent = 0;
    if (ctx->dry_base > 1.000001f) {
        max_exponent = FLOAT_MAX_LOG / std::log(ctx->dry_base);
    }

    for (size_t i = 0; i < cur_p->size; ++i) {
        const auto& af_kvp = ctx->dry_max_token_repeat.find(cur_p->data[i].id);
        if (af_kvp != ctx->dry_max_token_repeat.end()) {
            // Check all sequence breakers starting with this token
            auto range = ctx->dry_processed_breakers.equal_range(cur_p->data[i].id);
            bool is_single_token_breaker = false;

            for (auto it = range.first; it != range.second; ++it) {
                if (it->second.empty()) {
                    is_single_token_breaker = true;
                    break;
                }
            }

            // Apply penalty only if it's not a single-token sequence breaker
            if (!is_single_token_breaker) {
                int repeat_exp = af_kvp->second - ctx->dry_allowed_length;
                if (max_exponent > 0 && repeat_exp > max_exponent) {
                    repeat_exp = max_exponent;
                }
                float penalty = ctx->dry_multiplier * std::pow(ctx->dry_base, repeat_exp);
                cur_p->data[i].logit -= penalty;
            }
        }
    }

    cur_p->sorted = false;
}
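
// Illustrative sketch (not part of the diff): with the assumed settings dry_multiplier = 0.8,
// dry_base = 1.75 and dry_allowed_length = 2, a candidate token that would extend an
// already-seen sequence of length 5 gets
//   penalty = 0.8 * 1.75^(5 - 2) = 0.8 * 5.359... ~= 4.29
// subtracted from its logit, so longer repetitions are punished exponentially harder.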

static void llama_sampler_dry_reset(struct llama_sampler * smpl) {
    auto * ctx = (llama_sampler_dry *) smpl->ctx;
    ctx->last_tokens.clear();
    ctx->dry_repeat_count.clear();
    ctx->dry_max_token_repeat.clear();
}

static struct llama_sampler * llama_sampler_dry_clone(const struct llama_sampler * smpl) {
    const auto * ctx = (llama_sampler_dry *) smpl->ctx;

    // nullptr is passed as vocab because it is only needed for raw sequence breaker processing, which we have already done and will be copying
    auto * result = llama_sampler_init_dry(nullptr, ctx->dry_multiplier, ctx->dry_base, ctx->dry_allowed_length, ctx->dry_penalty_last_n, NULL, 0);

    // Copy the state, including the processed breakers
    {
        auto * result_ctx = (llama_sampler_dry *) result->ctx;
        result_ctx->dry_processed_breakers = ctx->dry_processed_breakers;
        result_ctx->dry_repeat_count       = ctx->dry_repeat_count;
        result_ctx->dry_max_token_repeat   = ctx->dry_max_token_repeat;
        result_ctx->last_tokens            = ctx->last_tokens;
    }

    return result;
}

static void llama_sampler_dry_free(struct llama_sampler * smpl) {
    delete (llama_sampler_dry *) smpl->ctx;
}

static struct llama_sampler_i llama_sampler_dry_i = {
    /* .name   = */ llama_sampler_dry_name,
    /* .accept = */ llama_sampler_dry_accept,
    /* .apply  = */ llama_sampler_dry_apply,
    /* .reset  = */ llama_sampler_dry_reset,
    /* .clone  = */ llama_sampler_dry_clone,
    /* .free   = */ llama_sampler_dry_free,
};

struct llama_sampler * llama_sampler_init_dry_impl(const struct llama_vocab & vocab, int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) {
    int32_t effective_dry_penalty_last_n = (dry_penalty_last_n == -1) ? context_size : std::max(dry_penalty_last_n, 0);
    std::unordered_multimap<llama_token, std::vector<llama_token>> processed_breakers;
    const int MAX_CHAR_LEN = 40;
    const int MAX_SEQ_LEN  = 20;

    const bool dry_enabled = (dry_multiplier != 0.0f && dry_base >= 1.0f && dry_penalty_last_n != 0);

    if (dry_enabled && seq_breakers != nullptr && num_breakers > 0) {
        // Process sequence breakers
        for (size_t i = 0; i < num_breakers; ++i) {
            if (seq_breakers[i] == nullptr || std::strlen(seq_breakers[i]) == 0) {
                LLAMA_LOG_WARN("skipping null or empty DRY sequence breaker at index %zu\n", i);
                continue;
            }

            std::string sequence_break(seq_breakers[i]);
            if (sequence_break.empty()) {
                LLAMA_LOG_WARN("skipping empty DRY sequence breaker\n");
                continue;
            }

            if (sequence_break.size() > MAX_CHAR_LEN) {
                LLAMA_LOG_WARN("truncating DRY sequence breaker to %d characters\n", MAX_CHAR_LEN);
                sequence_break.resize(MAX_CHAR_LEN);
            }

            get_overlapping_token_sequences(vocab, sequence_break, processed_breakers, MAX_SEQ_LEN);
        }
    }

    return new llama_sampler {
        /* .iface = */ &llama_sampler_dry_i,
        /* .ctx   = */ new llama_sampler_dry {
            /* .total_context_size     = */ context_size,
            /* .dry_multiplier         = */ dry_multiplier,
            /* .dry_base               = */ dry_base,
            /* .dry_allowed_length     = */ dry_allowed_length,
            /* .dry_penalty_last_n     = */ dry_penalty_last_n,
            /* .dry_processed_breakers = */ std::move(processed_breakers),
            /* .dry_repeat_count       = */ dry_enabled ? std::vector<int>(effective_dry_penalty_last_n, 0) : std::vector<int>{},
            /* .dry_max_token_repeat   = */ {},
            /* .last_tokens            = */ dry_enabled ? ring_buffer<llama_token>(effective_dry_penalty_last_n) : ring_buffer<llama_token>(0),
        },
    };
}
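
For intuition, get_overlapping_token_sequences fills the multimap with every token whose text could start one of the breaker strings, mapping it to the remaining tokens of that breaker; an empty tail marks a breaker that fits in a single token. A hypothetical illustration, relying on the headers already included above (the breaker strings and token ids are invented for the example):

// Hypothetical contents after processing the breakers {"\n", "User:"}.
// The token ids below are invented for illustration only.
std::unordered_multimap<llama_token, std::vector<llama_token>> example_breakers;
example_breakers.emplace(13,   std::vector<llama_token>{});        // "\n" fits in one token -> empty tail
example_breakers.emplace(2659, std::vector<llama_token>{28747});   // "User" followed by ":" spells "User:"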

// wrapper for test-sampling.cpp
struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const std::vector<std::vector<llama_token>>& seq_breakers) {
    llama_vocab dummy_vocab;
    auto * result = llama_sampler_init_dry_impl(dummy_vocab, context_size, dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, NULL, 0);
    auto * ctx = (llama_sampler_dry *) result->ctx;

    // Process the token-based sequence breakers
    ctx->dry_processed_breakers.clear();
    if (seq_breakers.empty()) {
        LLAMA_LOG_WARN("empty DRY sequence breakers list in llama_sampler_init_dry_testing\n");
    } else {
        for (const auto& breaker : seq_breakers) {
            if (breaker.empty()) {
                LLAMA_LOG_WARN("skipping empty DRY sequence breaker\n");
                continue;
            }
            llama_token head_token = breaker[0];
            std::vector<llama_token> tail_tokens(breaker.begin() + 1, breaker.end());
            ctx->dry_processed_breakers.emplace(head_token, std::move(tail_tokens));
        }

        if (ctx->dry_processed_breakers.empty()) {
            LLAMA_LOG_WARN("no valid DRY sequence breakers processed in llama_sampler_init_dry_testing\n");
        }
    }

    return result;
}
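
A minimal sketch of how a test might drive this wrapper; the token ids and parameter values are invented for illustration and the accept/apply plumbing is elided:

// Illustrative only: {42} stands in for a single-token sequence breaker.
std::vector<std::vector<llama_token>> test_breakers = { { 42 } };
struct llama_sampler * smpl = llama_sampler_init_dry_testing(
    /* context_size       = */ 4096,
    /* dry_multiplier     = */ 0.8f,
    /* dry_base           = */ 1.75f,
    /* dry_allowed_length = */ 2,
    /* dry_penalty_last_n = */ -1,
    test_breakers);
// ... feed a token history via the accept callback and apply to a llama_token_data_array ...
llama_sampler_free(smpl);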

// logit-bias

struct llama_sampler_logit_bias {

@@ -28,3 +28,21 @@ struct llama_sampler * llama_sampler_init_grammar_impl(

struct llama_sampler * llama_sampler_init_infill_impl(
    const struct llama_vocab & vocab);

struct llama_sampler * llama_sampler_init_dry_impl(
    const struct llama_vocab & vocab,
    int32_t context_size,
    float dry_multiplier,
    float dry_base,
    int32_t dry_allowed_length,
    int32_t dry_penalty_last_n,
    const char ** seq_breakers,
    size_t num_breakers);

struct llama_sampler * llama_sampler_init_dry_testing(
    int32_t context_size,
    float dry_multiplier,
    float dry_base,
    int32_t dry_allowed_length,
    int32_t dry_penalty_last_n,
    const std::vector<std::vector<llama_token>>& seq_breakers);

@@ -2226,3 +2226,19 @@ int32_t llama_detokenize_impl(

    return total <= text_len_max ? total : -total;
}

std::string llama_detokenize(const struct llama_vocab & vocab, const std::vector<llama_token> & tokens, bool special) {
    std::string text;
    text.resize(std::max(text.capacity(), tokens.size()));
    int32_t n_chars = llama_detokenize_impl(vocab, tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
    if (n_chars < 0) {
        text.resize(-n_chars);
        n_chars = llama_detokenize_impl(vocab, tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
        GGML_ASSERT(n_chars <= (int32_t)text.size()); // whitespace trimming is performed after per-token detokenization
    }

    text.resize(n_chars);

    // NOTE: the original tokenizer decodes bytes after collecting the pieces.
    return text;
}
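
A hedged sketch of a call site for the new helper, assuming a llama_vocab and a token sequence obtained elsewhere (the wrapper name is hypothetical, not part of this change):

// Hypothetical helper; `vocab` and `tokens` are assumed to come from elsewhere.
static std::string tokens_to_text(const struct llama_vocab & vocab, const std::vector<llama_token> & tokens) {
    // special = false drops control tokens from the rendered text
    return llama_detokenize(vocab, tokens, /*special=*/false);
}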
@@ -163,3 +163,8 @@ int32_t llama_detokenize_impl(
    int32_t text_len_max,
    bool remove_special,
    bool unparse_special);

std::string llama_detokenize(
    const struct llama_vocab & vocab,
    const std::vector<llama_token> & tokens,
    bool special);

@@ -9670,20 +9670,16 @@ static struct ggml_tensor * llm_build_kqv(
        cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
                                  hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);

        if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_GEMMA2) {
            ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
        }

        cur = ggml_reshape_2d(ctx, cur, n_embd_head_v*n_head, n_tokens);
    } else {
        struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
        cb(kq, "kq", il);

        if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2 || model.arch == LLM_ARCH_NEMOTRON || model.arch == LLM_ARCH_CHATGLM) {
            // note: this op tends to require high floating point range
            //       for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs
            //       while for some models F16 is enough, for others it is not, so we default to F32 here
            // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847
            ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
        }

        if (model.arch == LLM_ARCH_GROK) {
            // need to do the following:

@@ -9692,9 +9688,6 @@ static struct ggml_tensor * llm_build_kqv(
            // kq = 30 * tanh(kq / 30)
            // before the softmax below

            //try from phi2
            //ggml_mul_mat_set_prec(kq, GGML_PREC_F32);

            kq = ggml_tanh(ctx, ggml_scale(ctx, kq, 0.08838834764831845f/30.0f));
            kq = ggml_scale(ctx, kq, 30);
        }

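For clarity on the Grok branch above: the single pre-scale folds the usual 1/sqrt(d_k) attention scale together with the division by the cap value, so the pair of ops computes

    kq = 30 * tanh( (kq / sqrt(d_k)) / 30 )

and the constant 0.08838834764831845 is exactly 1/sqrt(128), consistent with a head dimension of 128 (an inference from the constant, not something stated in this change).
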
@@ -21793,6 +21786,16 @@ static int32_t llama_chat_apply_template_internal(
                ss << message->content << "\n\n";
            }
        }
    } else if (tmpl == "granite" || tmpl_contains("<|start_of_role|>")) {
        // IBM Granite template
        for (const auto & message : chat) {
            std::string role(message->role);
            ss << "<|start_of_role|>" << role << "<|end_of_role|>"
               << message->content << "<|end_of_text|>\n";
        }
        if (add_ass) {
            ss << "<|start_of_role|>assistant<|end_of_role|>\n";
        }
    } else {
        // template not supported
        return -1;

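To make the formatting concrete, a short user/assistant exchange rendered through the Granite branch above (with add_ass set) would come out roughly as follows; the message contents are invented:

<|start_of_role|>user<|end_of_role|>Hello!<|end_of_text|>
<|start_of_role|>assistant<|end_of_role|>Hi, how can I help?<|end_of_text|>
<|start_of_role|>assistant<|end_of_role|>
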
@@ -21855,6 +21858,10 @@ struct llama_sampler * llama_sampler_init_infill(const struct llama_model * model
    return llama_sampler_init_infill_impl(model->vocab);
}

struct llama_sampler * llama_sampler_init_dry(const struct llama_model * model, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) {
    return llama_sampler_init_dry_impl(model->vocab, llama_n_ctx_train(model), dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, seq_breakers, num_breakers);
}

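A hedged usage sketch of this public entry point, assuming a loaded llama_model and the llama_sampler_chain API from llama.h; the parameter values and breaker strings are illustrative, not mandated by this change:

// Illustrative values; `model` is assumed to be a loaded llama_model pointer.
const char * breakers[] = { "\n", ":", "\"", "*" };
struct llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());
llama_sampler_chain_add(chain, llama_sampler_init_dry(model,
    /* dry_multiplier     = */ 0.8f,
    /* dry_base           = */ 1.75f,
    /* dry_allowed_length = */ 2,
    /* dry_penalty_last_n = */ -1,
    breakers, 4));
llama_sampler_chain_add(chain, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));
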
//
// model split
//