server: fix bugs

Li, Zonghang 2025-07-13 13:42:24 +08:00
parent 0cf87c8837
commit b019a707b8
3 changed files with 13 additions and 5 deletions


@@ -673,6 +673,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
         format("size of the prompt context (default: %d, 0 = loaded from model)", params.n_ctx),
         [](gpt_params & params, int value) {
             params.n_ctx = value;
+            params.speculative.n_ctx = value;
         }
     ).set_env("LLAMA_ARG_CTX_SIZE"));
     add_opt(llama_arg(
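
Note on the hunk above: --ctx-size (LLAMA_ARG_CTX_SIZE) now also sets the speculative/draft context size, so the two no longer diverge when only the main flag is given. A minimal sketch of the invariant the extra assignment enforces, using stand-in structs rather than the real gpt_params:

    #include <cassert>

    // Stand-ins for the fields touched in the hunk (names assumed from the diff).
    struct speculative_params_sketch { int n_ctx = 0;    };
    struct gpt_params_sketch         { int n_ctx = 4096; speculative_params_sketch speculative; };

    // Mirrors the updated option handler: one flag sizes both contexts.
    static void set_ctx_size(gpt_params_sketch & p, int value) {
        p.n_ctx             = value;
        p.speculative.n_ctx = value;   // previously left at its default
    }

    int main() {
        gpt_params_sketch p;
        set_ctx_size(p, 8192);
        assert(p.n_ctx == 8192 && p.speculative.n_ctx == 8192);
    }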


@@ -760,6 +760,8 @@ struct server_context {
                 llama_free (llama_init_dft.context);
                 llama_free_model(llama_init_dft.model);
+                model_dft = nullptr;
                 return false;
             }
@@ -3566,6 +3568,8 @@ int main(int argc, char ** argv) {
     LOG_INF("%s: loading model\n", __func__);
     if (!ctx_server.load_model(params)) {
+        char * stop_signal = nullptr;
+        llama_free_sockets(ctx_server.ctx, &stop_signal);
         clean_up();
         t.join();
         LOG_ERR("%s: exiting due to model loading error\n", __func__);
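
Note on the two server.cpp hunks: when the draft model fails to initialize, the freed handle is now reset so later cleanup does not touch freed memory, and a failed load_model() now tears down the ZeroMQ sockets (llama_free_sockets, with the call shape used above) before clean_up() and the thread join, so peer nodes are signalled to stop instead of being left waiting. A small free-and-reset sketch of the first fix, illustrative only and not the server code:

    #include "llama.h"

    // Illustrative helper: free a draft model and null the handle so any later
    // "if (model_dft) llama_free_model(model_dft)" cleanup becomes a no-op.
    static void release_draft_model(llama_model *& model_dft) {
        if (model_dft != nullptr) {
            llama_free_model(model_dft);
            model_dft = nullptr;
        }
    }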


@@ -17878,7 +17878,7 @@ static void llama_send_meta(zmq::socket_t & socket, struct sync_meta * meta, boo
     if (meta->pos != nullptr) {
         send_msgs.emplace_back("pos", strlen("pos"));
-        send_msgs.emplace_back(meta->pos, meta->n_ctx * sizeof(llama_pos));
+        send_msgs.emplace_back(meta->pos, meta->n_tokens * sizeof(llama_pos));
     }
     if (meta->n_seq_id != nullptr) {
@@ -17986,8 +17986,8 @@ static int llama_recv_meta(zmq::socket_t & socket, struct sync_meta * meta) {
     }
     if (key == "pos") {
-        meta->pos = (llama_pos *) malloc(meta->n_ctx * sizeof(llama_pos));
-        std::memcpy(meta->pos, data_msg.data(), meta->n_ctx * sizeof(llama_pos));
+        meta->pos = (llama_pos *) malloc(meta->n_tokens * sizeof(llama_pos));
+        std::memcpy(meta->pos, data_msg.data(), meta->n_tokens * sizeof(llama_pos));
     }
     if (key == "n_seq_id") {
@@ -18304,8 +18304,8 @@ static int llama_decode_internal(
     if (meta.n_tokens > 0) {
         batch_all.n_tokens = meta.n_tokens;
         if (meta.pos != nullptr) {
-            batch_all.pos = (llama_pos *) malloc(meta.n_ctx * sizeof(llama_pos));
-            std::memcpy(batch_all.pos, meta.pos, meta.n_ctx * sizeof(llama_pos));
+            batch_all.pos = (llama_pos *) malloc(meta.n_tokens * sizeof(llama_pos));
+            std::memcpy(batch_all.pos, meta.pos, meta.n_tokens * sizeof(llama_pos));
         }
         if (meta.n_seq_id != nullptr) {
             batch_all.n_seq_id = (int32_t *) malloc(meta.n_tokens * sizeof(int32_t));
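
Note on the three llama.cpp hunks above: the pos array attached to a batch holds one entry per token in the batch, so its wire size and the receive-side malloc/memcpy must be n_tokens * sizeof(llama_pos). Sizing them by n_ctx over-read past the end of the source array on send and over-allocated on receive. A reduced sketch of the corrected copy, with a stand-in struct instead of the real sync_meta:

    #include <cstdint>
    #include <cstdlib>
    #include <cstring>

    using llama_pos = int32_t;   // matches the typedef in llama.h

    // Stand-in for the sync_meta fields used in the hunks above.
    struct sync_meta_sketch {
        uint32_t    n_tokens = 0;       // tokens in this batch: the array length
        uint32_t    n_ctx    = 0;       // context size: NOT the array length
        llama_pos * pos      = nullptr;
    };

    static void copy_pos(sync_meta_sketch & dst, const llama_pos * src, uint32_t n_tokens) {
        dst.n_tokens = n_tokens;
        dst.pos = (llama_pos *) malloc(n_tokens * sizeof(llama_pos));   // was n_ctx: over-allocates
        std::memcpy(dst.pos, src, n_tokens * sizeof(llama_pos));        // was n_ctx: reads past src
    }
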
@@ -22089,6 +22089,9 @@ void llama_model_compute_buf_size(
         // this value may vary by GPU and CUDA version, but it's lower than 400 MiB in most cases,
         // another 300 MiB is used to prevent accidental OOM.
         *gpu_buf += 700 * 1024 * 1024;
+    } else if (backend == BACKEND_METAL) {
+        // 300 MiB is used to prevent accidental OOM, e.g., automatic quantization conversion.
+        *gpu_buf += 300 * 1024 * 1024;
     }
 }
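
Note on the compute-buffer hunk: the CUDA path already pads the estimate by 700 MiB (a pool of under roughly 400 MiB plus a 300 MiB safety margin, per the comments above); the new branch adds a 300 MiB margin for Metal to avoid accidental OOM from, e.g., automatic quantization conversion. A compact sketch of the headroom arithmetic, with a hypothetical enum in place of the real backend constants:

    #include <cstddef>

    enum backend_kind_sketch { CUDA_SKETCH, METAL_SKETCH, OTHER_SKETCH };   // hypothetical names

    constexpr size_t MiB = 1024ull * 1024ull;

    static size_t extra_compute_buf(backend_kind_sketch b) {
        switch (b) {
            case CUDA_SKETCH:  return 700 * MiB;  // CUDA pool (< ~400 MiB) + ~300 MiB safety margin
            case METAL_SKETCH: return 300 * MiB;  // safety margin, e.g. on-the-fly quantization conversion
            default:           return 0;
        }
    }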