Mirror of https://github.com/Lizonghang/prima.cpp.git, synced 2025-09-07 03:39:04 +00:00
fix cuda support
parent 9d6a6845ac
commit 976a4c3534
1 changed file with 36 additions and 55 deletions
@@ -17317,8 +17317,8 @@ static void llama_graph_compute(
 }
 
 struct input_tensors {
-    ggml_tensor * sub_gf_out = nullptr;
-    ggml_tensor * inp_pos = nullptr;
+    ggml_tensor * sub_gf_out;
+    ggml_tensor * inp_pos;
 };
 
 struct sync_meta {
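Note on the struct change: a later hunk in this commit brace-initializes the struct ("struct input_tensors tensors = {sub_gf_out, lctx.inp_pos};"). If the build targets C++11, a struct with default member initializers is not an aggregate, so that brace form only compiles once the "= nullptr" defaults are gone; from C++14 on both variants accept brace initialization. Which standard prima.cpp builds with is an assumption here. A self-contained sketch with toy pointer members instead of ggml_tensor:

// illustrative only; mirrors the shape of input_tensors with plain pointers
struct tensor_refs_plain {          // aggregate under C++11 and later
    const float * sub_gf_out;
    const int   * inp_pos;
};

struct tensor_refs_defaulted {      // aggregate only from C++14 onward
    const float * sub_gf_out = nullptr;
    const int   * inp_pos    = nullptr;
};

int main() {
    float embd = 0.0f;
    int   pos  = 0;
    tensor_refs_plain a = {&embd, &pos};          // OK even with -std=c++11
    // tensor_refs_defaulted b = {&embd, &pos};   // error with -std=c++11, OK with C++14+
    (void) a;
    return 0;
}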
@@ -17340,7 +17340,7 @@ static void llama_send_meta(zmq::socket_t & socket, struct sync_meta * meta) {
     }
 }
 
-static void llama_send_tensors(zmq::socket_t & socket, struct input_tensors * tensors) {
+static void llama_send_tensors(zmq::socket_t & socket, struct llama_ubatch * ubatch, struct input_tensors * tensors) {
     try {
         std::vector<zmq::message_t> send_msgs;
         size_t buf_size = 0;
@@ -17348,13 +17348,13 @@ static void llama_send_tensors(zmq::socket_t & socket, struct input_tensors * te
         send_msgs.emplace_back("sub_gf_out", strlen("sub_gf_out"));
         send_msgs.emplace_back(tensors->sub_gf_out->ne, sizeof(tensors->sub_gf_out->ne));
         buf_size = tensors->sub_gf_out->ne[0] * tensors->sub_gf_out->ne[1] * sizeof(float);
-        send_msgs.emplace_back(tensors->sub_gf_out->data, buf_size);
+        send_msgs.emplace_back(ubatch->backend_embd, buf_size);
 
         if (tensors->inp_pos) {
             send_msgs.emplace_back("inp_pos", strlen("inp_pos"));
             send_msgs.emplace_back(tensors->inp_pos->ne, sizeof(tensors->inp_pos->ne[0]));
             buf_size = tensors->inp_pos->ne[0] * sizeof(int32_t);
-            send_msgs.emplace_back(tensors->inp_pos->data, buf_size);
+            send_msgs.emplace_back(ubatch->pos, buf_size);
         }
 
         zmq::send_multipart(socket, send_msgs);
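For context, llama_send_tensors packs each tensor as a [key | dims | payload] triplet inside one ZeroMQ multipart message, and after this change the payloads are read from host-side ubatch buffers instead of ggml tensor storage. The standalone sketch below reproduces only that framing over an inproc PAIR socket; the buffer names and sizes are invented for illustration, and it assumes nothing beyond libzmq with the cppzmq headers:

#include <zmq.hpp>
#include <zmq_addon.hpp>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <string>
#include <vector>

int main() {
    zmq::context_t ctx;
    zmq::socket_t rx(ctx, zmq::socket_type::pair);
    zmq::socket_t tx(ctx, zmq::socket_type::pair);
    rx.bind("inproc://tensors");
    tx.connect("inproc://tensors");

    // sender: [key | dims | payload], with the payload taken from host memory
    int64_t ne[2] = {4, 2};                                // toy n_embd x n_tokens
    std::vector<float> host_embd(ne[0] * ne[1], 1.0f);     // stand-in for ubatch->backend_embd

    std::vector<zmq::message_t> send_msgs;
    send_msgs.emplace_back("sub_gf_out", strlen("sub_gf_out"));
    send_msgs.emplace_back(ne, sizeof(ne));
    send_msgs.emplace_back(host_embd.data(), host_embd.size() * sizeof(float));
    zmq::send_multipart(tx, send_msgs);

    // receiver: copy the payload straight into another host buffer
    std::vector<zmq::message_t> recv_msgs;
    if (!zmq::recv_multipart(rx, std::back_inserter(recv_msgs))) {
        return 1;
    }
    std::string key = recv_msgs[0].to_string();
    const int64_t * dims = static_cast<const int64_t *>(recv_msgs[1].data());
    std::vector<float> out(dims[0] * dims[1]);
    std::memcpy(out.data(), recv_msgs[2].data(), recv_msgs[2].size());
    printf("%s: %lld x %lld, first value %.1f\n",
           key.c_str(), (long long) dims[0], (long long) dims[1], out[0]);
    return 0;
}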
@@ -17385,7 +17385,7 @@ static int llama_recv_meta(zmq::socket_t & socket, struct sync_meta * meta) {
     return 0;
 }
 
-static void llama_recv_tensors(zmq::socket_t & socket, input_tensors * tensors) {
+static void llama_recv_tensors(zmq::socket_t & socket, struct llama_ubatch * ubatch, struct llama_context * lctx, const bool is_out_embd=false) {
     std::vector<zmq::message_t> recv_msgs;
     if (!zmq::recv_multipart(socket, std::back_inserter(recv_msgs))) {
         LLAMA_LOG_INFO("Failed to receive tensor data.\n");
@@ -17396,19 +17396,15 @@ static void llama_recv_tensors(zmq::socket_t & socket, input_tensors * tensors)
         zmq::message_t &dims_msg = recv_msgs[i + 1];
         zmq::message_t &data_msg = recv_msgs[i + 2];
 
-        if (key == "sub_gf_out" && tensors->sub_gf_out) {
-            int64_t * dims = static_cast<int64_t*>(dims_msg.data());
-            size_t buf_size = dims[0] * dims[1] * sizeof(float);
-            GGML_ASSERT(dims[0] == tensors->sub_gf_out->ne[0]);
-            GGML_ASSERT(dims[1] == tensors->sub_gf_out->ne[1]);
-            GGML_ASSERT(data_msg.size() == buf_size);
-            std::memcpy(tensors->sub_gf_out->data, data_msg.data(), buf_size);
-        } else if (key == "inp_pos" && tensors->inp_pos) {
-            int64_t * dims = static_cast<int64_t*>(dims_msg.data());
+        if (key == "sub_gf_out") {
+            int64_t * dims = static_cast<int64_t *>(dims_msg.data());
+            size_t buf_size = dims[0] * dims[1] * sizeof(float);
+            float * batch_embd = is_out_embd ? ubatch->out_embd : ubatch->backend_embd;
+            std::memcpy(batch_embd, data_msg.data(), buf_size);
+        } else if (key == "inp_pos") {
+            int64_t * dims = static_cast<int64_t *>(dims_msg.data());
             size_t buf_size = dims[0] * sizeof(int32_t);
-            GGML_ASSERT(dims[0] == tensors->inp_pos->ne[0]);
-            GGML_ASSERT(data_msg.size() == buf_size);
-            std::memcpy(tensors->inp_pos->data, data_msg.data(), buf_size);
+            std::memcpy(ubatch->pos, data_msg.data(), buf_size);
         }
     }
 }
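The removed branch copied received bytes into tensors->sub_gf_out->data and tensors->inp_pos->data with std::memcpy. When those tensors live in a CUDA backend buffer, ->data refers to device memory, so a plain memcpy from host memory is not valid; the commit avoids the problem by writing into the host-side ubatch buffers instead. As a hedged sketch of the two copy options (it assumes an already-created ggml tensor plus an ordinary host buffer and is not code from this commit):

#include "ggml.h"
#include "ggml-backend.h"
#include <cstring>

// dst_tensor may live in VRAM; host_dst is plain CPU memory such as a ubatch buffer
static void store_received_payload(struct ggml_tensor * dst_tensor,
                                   float * host_dst,
                                   const void * payload, size_t n_bytes) {
    // backend-aware copy: valid whether the tensor is on the CPU or on a GPU
    ggml_backend_tensor_set(dst_tensor, payload, 0, n_bytes);

    // raw memcpy: only valid because host_dst is ordinary host memory
    std::memcpy(host_dst, payload, n_bytes);
}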
@@ -17725,29 +17721,14 @@ static int llama_decode_internal(
     for (size_t i = 0; i < (size_t)gf.size(); ++i) {
         sub_gf = gf[i];
 
+        // receive data from other nodes
         if (n_world > 1 && !(my_rank == 0 && i == 0) && !(my_rank == 0 && is_last_l)) {
-            // receive data from previous nodes
-            input_tensors tensors;
             const bool is_out_embd = my_rank == 0 && i == (size_t)gf.size() - 1;
-            tensors.sub_gf_out = is_out_embd ? lctx.out_embd : lctx.backend_embd;
-            tensors.inp_pos = lctx.inp_pos;
-            llama_recv_tensors(*lctx.recv_socket, &tensors);
-
-            is_last_l = my_rank == 0 && i == (size_t)gf.size() - 1;
-            size_t buf_size = tensors.sub_gf_out->ne[0] * tensors.sub_gf_out->ne[1] * ggml_element_size(tensors.sub_gf_out);
-            if (!is_last_l) {
-                memcpy(ubatch.backend_embd, tensors.sub_gf_out->data, buf_size);
-            } else {
-                memcpy(ubatch.out_embd, tensors.sub_gf_out->data, buf_size);
-            }
-            if (my_rank != 0 && i == 0) {
-                buf_size = tensors.inp_pos->ne[0] * ggml_element_size(tensors.inp_pos);
-                memcpy(ubatch.pos, tensors.inp_pos->data, buf_size);
-            }
+            llama_recv_tensors(*lctx.recv_socket, &ubatch, &lctx, is_out_embd);
         }
 
-        if (i > 0) {
-            // ensure ggml_backend_tensor_get_async of the previous subgraph has finished
+        // ensure ggml_backend_tensor_get_async of the previous subgraph has finished
+        if (i > 0 && (n_world == 1 || (my_rank == 0 && is_last_l))) {
            ggml_backend_sched_synchronize(lctx.sched[i - 1]);
         }
 
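In the reworked loop the receiving node writes straight into ubatch, and ggml_backend_sched_synchronize(lctx.sched[i - 1]) is kept only for the cases whose ggml_backend_tensor_get_async is not already synchronized on the send path in the next hunk. A rough standalone sketch of the underlying pattern, using assumed scheduler/graph arrays rather than the project's lctx fields: launch each subgraph asynchronously and block on the previous scheduler only right before its result is consumed.

#include "ggml-backend.h"

// sketch, not prima.cpp code: scheds[i] and graphs[i] are assumed to be set up elsewhere
static void compute_subgraphs(ggml_backend_sched_t * scheds,
                              struct ggml_cgraph ** graphs, int n_subgraphs) {
    for (int i = 0; i < n_subgraphs; ++i) {
        if (i > 0) {
            // subgraph i consumes the output of i-1, so wait for it here
            ggml_backend_sched_synchronize(scheds[i - 1]);
        }
        // enqueue subgraph i; this returns before the backend has finished
        ggml_backend_sched_graph_compute_async(scheds[i], graphs[i]);
    }
    // make the last subgraph's results visible before returning
    ggml_backend_sched_synchronize(scheds[n_subgraphs - 1]);
}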
@@ -17772,28 +17753,28 @@ static int llama_decode_internal(
             is_last_l = (cur_l == static_cast<int>(n_layer) - 1);
         }
 
-        // send the result to the next node (or the master)
+        float * embd_buf;
         if (n_world == 1 || (my_rank == 0 && is_last_l)) {
-            size_t buf_size = sub_gf_out->ne[0] * sub_gf_out->ne[1] * sizeof(float);
-            float * embd_buf = is_last_l ? ubatch.out_embd : ubatch.backend_embd;
-            ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(lctx.sched[i], sub_gf_out);
-
-            GGML_ASSERT(buf_size <= ggml_nbytes(sub_gf_out));
-            GGML_ASSERT(backend != nullptr);
-            GGML_ASSERT(embd_buf != nullptr);
-
-            ggml_backend_tensor_get_async(backend, sub_gf_out, embd_buf, 0, buf_size);
+            embd_buf = is_last_l ? ubatch.out_embd : ubatch.backend_embd;
         } else {
-            input_tensors tensors;
-            tensors.sub_gf_out = sub_gf_out;
-            if (i == 0 && !is_last_l && my_rank != n_world - 1) {
-                tensors.inp_pos = lctx.inp_pos;
-                const size_t buf_size = ubatch.n_tokens * ggml_element_size(tensors.inp_pos);
-                memcpy(tensors.inp_pos->data, ubatch.pos, buf_size);
-            }
+            embd_buf = ubatch.backend_embd;
+        }
+        GGML_ASSERT(embd_buf != nullptr);
+
+        // copy device data to cpu memory
+        size_t buf_size = sub_gf_out->ne[0] * sub_gf_out->ne[1] * sizeof(float);
+        ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(lctx.sched[i], sub_gf_out);
+        GGML_ASSERT(buf_size <= ggml_nbytes(sub_gf_out));
+        GGML_ASSERT(backend != nullptr);
+        ggml_backend_tensor_get_async(backend, sub_gf_out, embd_buf, 0, buf_size);
+
+        // send the result to the next node or the master
+        if (!(n_world == 1 || (my_rank == 0 && is_last_l))) {
+            struct input_tensors tensors = {sub_gf_out, lctx.inp_pos};
             const bool is_to_master = my_rank != 0 && is_last_l;
             zmq::socket_t * s = is_to_master ? lctx.master_socket : lctx.send_socket;
-            llama_send_tensors(*s, &tensors);
+            ggml_backend_sched_synchronize(lctx.sched[i]);
+            llama_send_tensors(*s, &ubatch, &tensors);
         }
 
         // overlap memory scheduling with other nodes' communication and computing
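With this change the send side always stages sub_gf_out into a host buffer via ggml_backend_tensor_get_async and, when that buffer is about to be serialized, calls ggml_backend_sched_synchronize first, so a CUDA device-to-host copy has finished before ZeroMQ reads the memory. A minimal sketch of that handoff as an assumed standalone helper (not the project's code):

#include "ggml.h"
#include "ggml-backend.h"

static void stage_output_for_send(ggml_backend_sched_t sched,
                                  struct ggml_tensor * sub_gf_out,
                                  float * host_buf, size_t buf_size) {
    ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, sub_gf_out);
    GGML_ASSERT(backend != nullptr);
    GGML_ASSERT(buf_size <= ggml_nbytes(sub_gf_out));

    // asynchronous device -> host copy into ordinary CPU memory
    ggml_backend_tensor_get_async(backend, sub_gf_out, host_buf, 0, buf_size);

    // without this, a GPU backend may still be filling host_buf while the
    // socket layer reads it; block until the copy has completed
    ggml_backend_sched_synchronize(sched);
}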