rpc : keep last_graph_uid in the device context (#23273)

With the introduction of MTP we can have multiple compute contexts for
the same RPC device. In this case last_graph_uid is not updated properly
when contexts are being switched. This patch fixes this by moving
last_graph_uid to the device context, making sure it is always updated.

closes: #23242
This commit is contained in:
Radoslav Gerganov 2026-05-19 09:42:36 +03:00 committed by GitHub
parent 9a532ae4ba
commit c3e9ade6dd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -199,6 +199,14 @@ static ggml_guid_t ggml_backend_rpc_guid() {
return &guid;
}
struct ggml_backend_rpc_device_context {
std::string endpoint;
uint32_t device;
std::string name;
std::string description;
uint64_t last_graph_uid;
};
struct ggml_backend_rpc_buffer_type_context {
std::string endpoint;
uint32_t device;
@ -211,7 +219,6 @@ struct ggml_backend_rpc_context {
std::string endpoint;
uint32_t device;
std::string name;
uint64_t last_graph_uid;
};
struct ggml_backend_rpc_buffer_context {
@ -691,9 +698,11 @@ static void serialize_graph(uint32_t device, const ggml_cgraph * cgraph, std::ve
static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
ggml_backend_dev_t rpc_dev = ggml_backend_get_device(backend);
ggml_backend_rpc_device_context * rpc_dev_ctx = (ggml_backend_rpc_device_context *)rpc_dev->context;
GGML_ASSERT(cgraph->n_nodes > 0);
bool reuse = cgraph->uid != 0 && rpc_ctx->last_graph_uid == cgraph->uid;
bool reuse = cgraph->uid != 0 && rpc_dev_ctx->last_graph_uid == cgraph->uid;
if (reuse) {
rpc_msg_graph_recompute_req request;
request.device = rpc_ctx->device;
@ -701,7 +710,7 @@ static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, g
bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_RECOMPUTE, &request, sizeof(request));
RPC_STATUS_ASSERT(status);
} else {
rpc_ctx->last_graph_uid = cgraph->uid;
rpc_dev_ctx->last_graph_uid = cgraph->uid;
std::vector<uint8_t> input;
serialize_graph(rpc_ctx->device, cgraph, input);
auto sock = get_socket(rpc_ctx->endpoint);
@ -770,7 +779,6 @@ ggml_backend_t ggml_backend_rpc_init(const char * endpoint, uint32_t device) {
/* .endpoint = */ endpoint,
/* .device = */ device,
/* .name = */ dev_name,
/* .last_graph_uid = */ 0,
};
auto reg = ggml_backend_rpc_add_server(endpoint);
ggml_backend_t backend = new ggml_backend {
@ -1757,15 +1765,6 @@ void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir
}
}
// device interface
struct ggml_backend_rpc_device_context {
std::string endpoint;
uint32_t device;
std::string name;
std::string description;
};
static const char * ggml_backend_rpc_device_get_name(ggml_backend_dev_t dev) {
ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
@ -1947,10 +1946,11 @@ ggml_backend_reg_t ggml_backend_rpc_add_server(const char * endpoint) {
std::string dev_name = "RPC" + std::to_string(dev_id);
std::string dev_desc = std::string(endpoint);
ggml_backend_rpc_device_context * dev_ctx = new ggml_backend_rpc_device_context {
/* .endpoint = */ endpoint,
/* .device = */ ind,
/* .name = */ dev_name,
/* .description = */ dev_desc
/* .endpoint = */ endpoint,
/* .device = */ ind,
/* .name = */ dev_name,
/* .description = */ dev_desc,
/* .last_graph_uid = */ 0,
};
ggml_backend_dev_t dev = new ggml_backend_device {