diff --git a/common/common.cpp b/common/common.cpp
index a21146b7..699a2dd7 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -28,6 +28,7 @@
 #include
 #include
 #include
+#include <thread>
 
 #if defined(__APPLE__) && defined(__MACH__)
 #include
@@ -1806,10 +1807,20 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
     llama_update_context_with_rankworld(lctx, update_rank, update_n_world, worker_rank, n_worker);
 
     if(node_type == NodeType::NODE_TYPE_FORWARDER){
-        //just foward
-        while (true) {
-            llama_forward_messages(lctx);
-        }
+        //just forward
+        std::atomic<bool> should_exit{false};
+        auto t = std::thread([lctx, &should_exit]() {
+            while(!should_exit) {
+                llama_forward_messages(lctx);
+            }
+        });
+        char * stop_signal = nullptr;
+        llama_free_sockets(lctx, &stop_signal); // this will block until a stop signal is received
+
+        should_exit = true;
+        t.join();
+
+        exit(0);
     }
 
     // update n_layer_window and n_gpu_layers
diff --git a/src/llama.cpp b/src/llama.cpp
index dd0540cf..4880d029 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -20776,9 +20776,13 @@ int llama_rebuild_topo(llama_context * ctx,
 int llama_forward_messages(llama_context *ctx) {
     zmq::message_t message;
     int more = true;
-
+    int timeout_ms = 10;
+    ctx->recv_socket->setsockopt(ZMQ_RCVTIMEO, &timeout_ms, sizeof(timeout_ms));
     while (more) {
-        ctx->recv_socket->recv(message, zmq::recv_flags::none);
+        auto recv_result = ctx->recv_socket->recv(message, zmq::recv_flags::none);
+        if (!recv_result) {
+            return -1;
+        }
         size_t more_size = sizeof(more);
         ctx->recv_socket->getsockopt(ZMQ_RCVMORE, &more, &more_size);
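
Why the recv timeout matters: before this patch, llama_forward_messages() blocked indefinitely inside recv(), so the forwarder's infinite loop could never be interrupted and the process had to be killed. With ZMQ_RCVTIMEO set to 10 ms, a timed-out recv() returns an empty result, the function bails out with -1, and the loop re-checks the should_exit flag at least every ~10 ms, which is what makes t.join() safe after llama_free_sockets() returns.

The shutdown protocol reduces to the minimal, self-contained sketch below. forward_once() is a hypothetical stand-in for llama_forward_messages(), and its sleep models a recv() bounded by ZMQ_RCVTIMEO, so the sketch runs without ZeroMQ; everything else follows the patch.

    #include <atomic>
    #include <chrono>
    #include <cstdio>
    #include <thread>

    // Stand-in for llama_forward_messages(): waits up to ~10 ms for work,
    // mirroring a recv() bounded by ZMQ_RCVTIMEO, then reports a timeout.
    static int forward_once() {
        std::this_thread::sleep_for(std::chrono::milliseconds(10));
        return -1; // no message arrived within the timeout window
    }

    int main() {
        std::atomic<bool> should_exit{false};

        // Forwarder loop: each iteration blocks for at most the recv timeout,
        // so the exit flag is re-checked every ~10 ms.
        std::thread t([&should_exit]() {
            while (!should_exit) {
                forward_once();
            }
        });

        // Stand-in for llama_free_sockets() blocking until a stop signal arrives.
        std::this_thread::sleep_for(std::chrono::milliseconds(100));

        should_exit = true; // request shutdown
        t.join();           // safe: the loop can never block longer than ~10 ms
        std::puts("forwarder stopped cleanly");
        return 0;
    }

The 10 ms timeout is a latency/CPU trade-off: a smaller value makes shutdown snappier but wakes the forwarder thread more often, and a larger one does the opposite. Note that a timed-out iteration returns -1 just like a real receive failure; the forwarder loop ignores the return value, so this is benign here, but a caller that needed to distinguish the two cases would have to check the exit flag itself.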