mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2026-05-17 04:09:19 +00:00
rpc : use graph uid instead of graph cache (#22701)
Store the last graph uid and compare against it to determine if the same graph is being computed.
This commit is contained in:
parent
2635ac76e8
commit
d5003b6e4d
1 changed files with 7 additions and 31 deletions
|
|
@ -207,35 +207,11 @@ struct ggml_backend_rpc_buffer_type_context {
|
|||
size_t max_size;
|
||||
};
|
||||
|
||||
struct graph_cache {
|
||||
|
||||
bool is_cached(const ggml_cgraph * cgraph) {
|
||||
if ((int)last_graph.size() != cgraph->n_nodes) {
|
||||
return false;
|
||||
}
|
||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||
if (memcmp(&last_graph[i], cgraph->nodes[i], sizeof(ggml_tensor)) != 0) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void add(const ggml_cgraph * cgraph) {
|
||||
last_graph.resize(cgraph->n_nodes);
|
||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||
memcpy(&last_graph[i], cgraph->nodes[i], sizeof(ggml_tensor));
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<ggml_tensor> last_graph;
|
||||
};
|
||||
|
||||
struct ggml_backend_rpc_context {
|
||||
std::string endpoint;
|
||||
uint32_t device;
|
||||
std::string name;
|
||||
graph_cache gc;
|
||||
uint64_t last_graph_uid;
|
||||
};
|
||||
|
||||
struct ggml_backend_rpc_buffer_context {
|
||||
|
|
@ -717,7 +693,7 @@ static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, g
|
|||
ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
|
||||
|
||||
GGML_ASSERT(cgraph->n_nodes > 0);
|
||||
bool reuse = rpc_ctx->gc.is_cached(cgraph);
|
||||
bool reuse = cgraph->uid != 0 && rpc_ctx->last_graph_uid == cgraph->uid;
|
||||
if (reuse) {
|
||||
rpc_msg_graph_recompute_req request;
|
||||
request.device = rpc_ctx->device;
|
||||
|
|
@ -725,7 +701,7 @@ static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, g
|
|||
bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_RECOMPUTE, &request, sizeof(request));
|
||||
RPC_STATUS_ASSERT(status);
|
||||
} else {
|
||||
rpc_ctx->gc.add(cgraph);
|
||||
rpc_ctx->last_graph_uid = cgraph->uid;
|
||||
std::vector<uint8_t> input;
|
||||
serialize_graph(rpc_ctx->device, cgraph, input);
|
||||
auto sock = get_socket(rpc_ctx->endpoint);
|
||||
|
|
@ -791,10 +767,10 @@ ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint, u
|
|||
ggml_backend_t ggml_backend_rpc_init(const char * endpoint, uint32_t device) {
|
||||
std::string dev_name = "RPC" + std::to_string(device) + "[" + std::string(endpoint) + "]";
|
||||
ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {
|
||||
/* .endpoint = */ endpoint,
|
||||
/* .device = */ device,
|
||||
/* .name = */ dev_name,
|
||||
/* .gc = */ {},
|
||||
/* .endpoint = */ endpoint,
|
||||
/* .device = */ device,
|
||||
/* .name = */ dev_name,
|
||||
/* .last_graph_uid = */ 0,
|
||||
};
|
||||
auto reg = ggml_backend_rpc_add_server(endpoint);
|
||||
ggml_backend_t backend = new ggml_backend {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue