Mirror of https://github.com/LostRuins/koboldcpp.git

merged, try to fix metal build

commit ec5dea14d7

29 changed files with 1541 additions and 967 deletions
@@ -148,7 +148,7 @@ struct server_slot {
     int32_t n_decoded = 0;
     int32_t n_remaining = -1;
     int32_t i_batch = -1;
-    int32_t n_predict = -1;
+    int32_t n_predict = -1; // TODO: disambiguate from params.n_predict

     int32_t n_prompt_tokens = 0;
     int32_t n_prompt_tokens_processed = 0;
@@ -740,7 +740,13 @@ struct server_context {
         default_generation_settings_for_props = get_formated_generation(slots.front());
         default_generation_settings_for_props["seed"] = -1;

-        batch = llama_batch_init(n_ctx, 0, params.n_parallel);
+        // the update_slots() logic will always submit a maximum of n_batch tokens
+        // note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used)
+        {
+            const int32_t n_batch = llama_n_batch(ctx);
+
+            batch = llama_batch_init(n_batch, 0, params.n_parallel);
+        }

         metrics.init();
     }
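Note: the hunk above stops sizing the batch by n_ctx and instead asks the context for its logical batch limit. A minimal sketch of the same allocation pattern, assuming a llama_context * ctx has already been created (llama_n_batch, llama_batch_init, and llama_batch_free are the real llama.cpp calls; the single-sequence setup is illustrative):

    // size the reusable batch by the logical batch limit, not by n_ctx;
    // for non-causal models (e.g. BERT embeddings) n_batch may exceed n_ctx
    const int32_t n_batch = llama_n_batch(ctx); // max tokens per llama_decode call
    llama_batch batch = llama_batch_init(n_batch, /*embd*/ 0, /*n_seq_max*/ 1);
    // ... fill the batch and call llama_decode ...
    llama_batch_free(batch);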
@@ -1037,8 +1043,10 @@ struct server_context {
             llama_batch_add(batch, system_tokens[i], i, { 0 }, false);
         }

-        for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += params.n_batch) {
-            const int32_t n_tokens = std::min(params.n_batch, (int32_t) (batch.n_tokens - i));
+        const int32_t n_batch = llama_n_batch(ctx);
+
+        for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
+            const int32_t n_tokens = std::min(params.n_batch, batch.n_tokens - i);

             llama_batch batch_view = {
                 n_tokens,
                 batch.token + i,
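Note: the hunk is truncated just after batch.token + i. For context, a hedged sketch of the full batch_view the loop builds, following the llama_batch field layout llama.cpp used at the time (the pointer offsets are the point: each view is a window into the one big batch, with no copying; ctx is the server's llama_context):

    llama_batch batch_view = {
        n_tokens,
        batch.token    + i,
        nullptr,             // embd: unused when feeding token ids
        batch.pos      + i,
        batch.n_seq_id + i,
        batch.seq_id   + i,
        batch.logits   + i,
        0, 0, 0,             // all_pos_0, all_pos_1, all_seq_id: unused when pos/seq_id are set
    };
    const int ret = llama_decode(ctx, batch_view);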
@@ -1227,7 +1235,7 @@ struct server_context {
            {"mirostat_eta", slot.sparams.mirostat_eta},
            {"penalize_nl", slot.sparams.penalize_nl},
            {"stop", slot.params.antiprompt},
-           {"n_predict", slot.params.n_predict},
+           {"n_predict", slot.params.n_predict}, // TODO: fix duplicate key n_predict
            {"n_keep", params.n_keep},
            {"ignore_eos", ignore_eos},
            {"stream", slot.params.stream},
@@ -1739,7 +1747,8 @@ struct server_context {
        }

        // process in chunks of params.n_batch
-       int32_t n_batch = params.n_batch;
+       int32_t n_batch = llama_n_batch(ctx);
+       int32_t n_ubatch = llama_n_ubatch(ctx);

        // next, batch any pending prompts without exceeding n_batch
        if (params.cont_batching || batch.n_tokens == 0) {
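Note: this is the core of the change. n_batch (logical) is how many tokens one llama_decode call may submit; n_ubatch (physical) is how many the backend actually runs through the model per forward pass. A hedged sketch of how the two are configured, assuming the usual context-creation flow (the field names are real llama_context_params members; the numbers are illustrative):

    llama_context_params cparams = llama_context_default_params();
    cparams.n_batch  = 2048; // logical: max tokens accepted by one llama_decode call
    cparams.n_ubatch = 512;  // physical: max tokens executed per forward pass
    // after creating the context, read the effective values back with:
    //   llama_n_batch(ctx)  -> logical maximum
    //   llama_n_ubatch(ctx) -> physical maximum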
@@ -1812,7 +1821,7 @@ struct server_context {

                if (slot.embedding) {
                    // this prompt is too large to process - discard it
-                   if (slot.n_prompt_tokens > n_batch) {
+                   if (slot.n_prompt_tokens > n_ubatch) {
                        slot.state = SLOT_STATE_PROCESSING;
                        slot.command = SLOT_COMMAND_NONE;
                        slot.release();
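Note: the embedding guard switches from n_batch to n_ubatch because a non-causal (embedding) prompt cannot be split across physical batches: there is no KV cache to carry attention state between ubatches, so the whole prompt must fit in a single forward pass. The shape of the check, using the names from the hunk above:

    if (slot.embedding && slot.n_prompt_tokens > n_ubatch) {
        // too large for one physical batch: reject rather than return a bogus embedding
        slot.release();
    }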
@@ -2158,7 +2167,8 @@ static void server_print_usage(const char * argv0, const gpt_params & params, co
    printf("  --pooling {none,mean,cls} pooling type for embeddings, use model default if unspecified\n");
    printf("  -dt N, --defrag-thold N\n");
    printf("                        KV cache defragmentation threshold (default: %.1f, < 0 - disabled)\n", params.defrag_thold);
-   printf("  -b N, --batch-size N  batch size for prompt processing (default: %d)\n", params.n_batch);
+   printf("  -b N, --batch-size N  logical maximum batch size (default: %d)\n", params.n_batch);
+   printf("  -ub N, --ubatch-size N physical maximum batch size (default: %d)\n", params.n_ubatch);
    printf("  --memory-f32          use f32 instead of f16 for memory key+value (default: disabled)\n");
    printf("                        not recommended: doubles context memory required and no measurable increase in quality\n");
    if (llama_supports_mlock()) {
@@ -2425,6 +2435,12 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams,
                break;
            }
            params.n_batch = std::stoi(argv[i]);
+       } else if (arg == "-ub" || arg == "--ubatch-size") {
+           if (++i >= argc) {
+               invalid_param = true;
+               break;
+           }
+           params.n_ubatch = std::stoi(argv[i]);
        } else if (arg == "--gpu-layers" || arg == "-ngl" || arg == "--n-gpu-layers") {
            if (++i >= argc) {
                invalid_param = true;
@@ -2764,6 +2780,7 @@ int main(int argc, char ** argv) {
        res.set_header("Access-Control-Allow-Credentials", "true");
        res.set_header("Access-Control-Allow-Methods", "POST");
        res.set_header("Access-Control-Allow-Headers", "*");
+       return res.set_content("", "application/json; charset=utf-8");
    });

    svr->set_logger(log_server_request);
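Note: the added return ends the CORS preflight handler with an explicit, well-formed empty response instead of falling through after only setting headers. A minimal sketch of the pattern with cpp-httplib (the Allow-* headers here stand in for the ones set just above; svr is the server's httplib::Server pointer):

    svr->Options(R"(/.*)", [](const httplib::Request &, httplib::Response & res) {
        res.set_header("Access-Control-Allow-Methods", "POST");
        res.set_header("Access-Control-Allow-Headers", "*");
        res.set_content("", "application/json; charset=utf-8"); // empty body completes the preflight
    });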
@@ -3372,44 +3389,37 @@ int main(int argc, char ** argv) {
        const json body = json::parse(req.body);
        bool is_openai = false;

-       // an input prompt can string or a list of tokens (integer)
-       std::vector<json> prompts;
+       // an input prompt can be a string or a list of tokens (integer)
+       json prompt;
        if (body.count("input") != 0) {
            is_openai = true;
-           if (body["input"].is_array()) {
-               // support multiple prompts
-               for (const json & elem : body["input"]) {
-                   prompts.push_back(elem);
-               }
-           } else {
-               // single input prompt
-               prompts.push_back(body["input"]);
-           }
+           prompt = body["input"];
        } else if (body.count("content") != 0) {
-           // only support single prompt here
-           std::string content = body["content"];
-           prompts.push_back(content);
+           // with "content", we only support single prompt
+           prompt = std::vector<std::string>{body["content"]};
        } else {
            res_error(res, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST));
            return;
        }

-       // process all prompts
-       json responses = json::array();
-       for (auto & prompt : prompts) {
-           // TODO @ngxson : maybe support multitask for this endpoint?
-           // create and queue the task
+       // create and queue the task
+       json responses;
+       {
            const int id_task = ctx_server.queue_tasks.get_new_id();

            ctx_server.queue_results.add_waiting_task_id(id_task);
-           ctx_server.request_completion(id_task, -1, { {"prompt", prompt}, { "n_predict", 0}}, false, true);
+           ctx_server.request_completion(id_task, -1, {{"prompt", prompt}}, false, true);

            // get the result
            server_task_result result = ctx_server.queue_results.recv(id_task);
            ctx_server.queue_results.remove_waiting_task_id(id_task);
            if (!result.error) {
-               // append to the responses
-               responses.push_back(result.data);
+               if (result.data.count("results")) {
+                   // result for multi-task
+                   responses = result.data["results"];
+               } else {
+                   // result for single task
+                   responses = std::vector<json>{result.data};
+               }
            } else {
                // error received, ignore everything else
                res_error(res, result.data);
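Note: after this refactor the handler forwards "input" untouched (string or array) and wraps the legacy "content" form into a one-element list, letting request_completion's multitask path fan the prompts out internally instead of looping here. Illustrative request bodies the handler now accepts, written with nlohmann::json as the file itself does:

    json oai_single = {{"input", "hello world"}};                     // OpenAI style, one prompt
    json oai_multi  = {{"input", {"first prompt", "second prompt"}}}; // OpenAI style, several prompts
    json legacy     = {{"content", "hello world"}};                   // legacy single-prompt form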
@@ -3418,24 +3428,19 @@ int main(int argc, char ** argv) {
        }

        // write JSON response
-       json root;
-       if (is_openai) {
-           json res_oai = json::array();
-           int i = 0;
-           for (auto & elem : responses) {
-               res_oai.push_back(json{
-                   {"embedding", json_value(elem, "embedding", json::array())},
-                   {"index", i++},
-                   {"object", "embedding"}
-               });
-           }
-           root = format_embeddings_response_oaicompat(body, res_oai);
-       } else {
-           root = responses[0];
-       }
+       json root = is_openai
+           ? format_embeddings_response_oaicompat(body, responses)
+           : responses[0];
        return res.set_content(root.dump(), "application/json; charset=utf-8");
    };

+   auto handle_static_file = [](unsigned char * content, size_t len, const char * mime_type) {
+       return [content, len, mime_type](const httplib::Request &, httplib::Response & res) {
+           res.set_content(reinterpret_cast<const char*>(content), len, mime_type);
+           return false;
+       };
+   };
+
    //
    // Router
    //
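Note: the per-element {embedding, index, object} wrapping removed here moves into format_embeddings_response_oaicompat (defined in the server's utils), so both branches now receive the raw results array. A hedged sketch of the OpenAI-compatible shape that helper is expected to produce (the exact fields depend on the helper; the "model" and "usage" values are illustrative):

    json root = {
        {"object", "list"},
        {"data", json::array({
            {{"object", "embedding"}, {"index", 0}, {"embedding", json::array()}}
        })},
        {"model", "..."}, // taken from the request when present
        {"usage", {{"prompt_tokens", 0}, {"total_tokens", 0}}}
    };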
@@ -3447,17 +3452,6 @@ int main(int argc, char ** argv) {
    }

-   // using embedded static files
-   auto handle_static_file = [](unsigned char * content, size_t len, const char * mime_type) {
-       return [content, len, mime_type](const httplib::Request &, httplib::Response & res) {
-           res.set_content(reinterpret_cast<const char*>(content), len, mime_type);
-           return false;
-       };
-   };
-
-   svr->Options(R"(/.*)", [](const httplib::Request &, httplib::Response & res) {
-       // TODO @ngxson : I have no idea what it is... maybe this is redundant?
-       return res.set_content("", "application/json; charset=utf-8");
-   });
    svr->Get("/", handle_static_file(index_html, index_html_len, "text/html; charset=utf-8"));
    svr->Get("/index.js", handle_static_file(index_js, index_js_len, "text/javascript; charset=utf-8"));
    svr->Get("/completion.js", handle_static_file(completion_js, completion_js_len, "text/javascript; charset=utf-8"));