fix cuda support

This commit is contained in:
Lizonghang 2024-11-04 11:00:01 +04:00
parent 9d6a6845ac
commit 976a4c3534

View file

@ -17317,8 +17317,8 @@ static void llama_graph_compute(
}
struct input_tensors {
ggml_tensor * sub_gf_out = nullptr;
ggml_tensor * inp_pos = nullptr;
ggml_tensor * sub_gf_out;
ggml_tensor * inp_pos;
};
struct sync_meta {
@ -17340,7 +17340,7 @@ static void llama_send_meta(zmq::socket_t & socket, struct sync_meta * meta) {
}
}
static void llama_send_tensors(zmq::socket_t & socket, struct input_tensors * tensors) {
static void llama_send_tensors(zmq::socket_t & socket, struct llama_ubatch * ubatch, struct input_tensors * tensors) {
try {
std::vector<zmq::message_t> send_msgs;
size_t buf_size = 0;
@ -17348,13 +17348,13 @@ static void llama_send_tensors(zmq::socket_t & socket, struct input_tensors * te
send_msgs.emplace_back("sub_gf_out", strlen("sub_gf_out"));
send_msgs.emplace_back(tensors->sub_gf_out->ne, sizeof(tensors->sub_gf_out->ne));
buf_size = tensors->sub_gf_out->ne[0] * tensors->sub_gf_out->ne[1] * sizeof(float);
send_msgs.emplace_back(tensors->sub_gf_out->data, buf_size);
send_msgs.emplace_back(ubatch->backend_embd, buf_size);
if (tensors->inp_pos) {
send_msgs.emplace_back("inp_pos", strlen("inp_pos"));
send_msgs.emplace_back(tensors->inp_pos->ne, sizeof(tensors->inp_pos->ne[0]));
buf_size = tensors->inp_pos->ne[0] * sizeof(int32_t);
send_msgs.emplace_back(tensors->inp_pos->data, buf_size);
send_msgs.emplace_back(ubatch->pos, buf_size);
}
zmq::send_multipart(socket, send_msgs);
@ -17385,7 +17385,7 @@ static int llama_recv_meta(zmq::socket_t & socket, struct sync_meta * meta) {
return 0;
}
static void llama_recv_tensors(zmq::socket_t & socket, input_tensors * tensors) {
static void llama_recv_tensors(zmq::socket_t & socket, struct llama_ubatch * ubatch, struct llama_context * lctx, const bool is_out_embd=false) {
std::vector<zmq::message_t> recv_msgs;
if (!zmq::recv_multipart(socket, std::back_inserter(recv_msgs))) {
LLAMA_LOG_INFO("Failed to receive tensor data.\n");
@ -17396,19 +17396,15 @@ static void llama_recv_tensors(zmq::socket_t & socket, input_tensors * tensors)
zmq::message_t &dims_msg = recv_msgs[i + 1];
zmq::message_t &data_msg = recv_msgs[i + 2];
if (key == "sub_gf_out" && tensors->sub_gf_out) {
int64_t * dims = static_cast<int64_t*>(dims_msg.data());
size_t buf_size = dims[0] * dims[1] * sizeof(float);
GGML_ASSERT(dims[0] == tensors->sub_gf_out->ne[0]);
GGML_ASSERT(dims[1] == tensors->sub_gf_out->ne[1]);
GGML_ASSERT(data_msg.size() == buf_size);
std::memcpy(tensors->sub_gf_out->data, data_msg.data(), buf_size);
} else if (key == "inp_pos" && tensors->inp_pos) {
int64_t * dims = static_cast<int64_t*>(dims_msg.data());
if (key == "sub_gf_out") {
int64_t * dims = static_cast<int64_t *>(dims_msg.data());
size_t buf_size = dims[0] * dims[1] * sizeof(float);
float * batch_embd = is_out_embd ? ubatch->out_embd : ubatch->backend_embd;
std::memcpy(batch_embd, data_msg.data(), buf_size);
} else if (key == "inp_pos") {
int64_t * dims = static_cast<int64_t *>(dims_msg.data());
size_t buf_size = dims[0] * sizeof(int32_t);
GGML_ASSERT(dims[0] == tensors->inp_pos->ne[0]);
GGML_ASSERT(data_msg.size() == buf_size);
std::memcpy(tensors->inp_pos->data, data_msg.data(), buf_size);
std::memcpy(ubatch->pos, data_msg.data(), buf_size);
}
}
}
@ -17725,29 +17721,14 @@ static int llama_decode_internal(
for (size_t i = 0; i < (size_t)gf.size(); ++i) {
sub_gf = gf[i];
// receive data from other nodes
if (n_world > 1 && !(my_rank == 0 && i == 0) && !(my_rank == 0 && is_last_l)) {
// receive data from previous nodes
input_tensors tensors;
const bool is_out_embd = my_rank == 0 && i == (size_t)gf.size() - 1;
tensors.sub_gf_out = is_out_embd ? lctx.out_embd : lctx.backend_embd;
tensors.inp_pos = lctx.inp_pos;
llama_recv_tensors(*lctx.recv_socket, &tensors);
is_last_l = my_rank == 0 && i == (size_t)gf.size() - 1;
size_t buf_size = tensors.sub_gf_out->ne[0] * tensors.sub_gf_out->ne[1] * ggml_element_size(tensors.sub_gf_out);
if (!is_last_l) {
memcpy(ubatch.backend_embd, tensors.sub_gf_out->data, buf_size);
} else {
memcpy(ubatch.out_embd, tensors.sub_gf_out->data, buf_size);
}
if (my_rank != 0 && i == 0) {
buf_size = tensors.inp_pos->ne[0] * ggml_element_size(tensors.inp_pos);
memcpy(ubatch.pos, tensors.inp_pos->data, buf_size);
}
llama_recv_tensors(*lctx.recv_socket, &ubatch, &lctx, is_out_embd);
}
if (i > 0) {
// ensure ggml_backend_tensor_get_async of the previous subgraph has finished
// ensure ggml_backend_tensor_get_async of the previous subgraph has finished
if (i > 0 && (n_world == 1 || (my_rank == 0 && is_last_l))) {
ggml_backend_sched_synchronize(lctx.sched[i - 1]);
}
@ -17772,28 +17753,28 @@ static int llama_decode_internal(
is_last_l = (cur_l == static_cast<int>(n_layer) - 1);
}
// send the result to the next node (or the master)
float * embd_buf;
if (n_world == 1 || (my_rank == 0 && is_last_l)) {
size_t buf_size = sub_gf_out->ne[0] * sub_gf_out->ne[1] * sizeof(float);
float * embd_buf = is_last_l ? ubatch.out_embd : ubatch.backend_embd;
ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(lctx.sched[i], sub_gf_out);
GGML_ASSERT(buf_size <= ggml_nbytes(sub_gf_out));
GGML_ASSERT(backend != nullptr);
GGML_ASSERT(embd_buf != nullptr);
ggml_backend_tensor_get_async(backend, sub_gf_out, embd_buf, 0, buf_size);
embd_buf = is_last_l ? ubatch.out_embd : ubatch.backend_embd;
} else {
input_tensors tensors;
tensors.sub_gf_out = sub_gf_out;
if (i == 0 && !is_last_l && my_rank != n_world - 1) {
tensors.inp_pos = lctx.inp_pos;
const size_t buf_size = ubatch.n_tokens * ggml_element_size(tensors.inp_pos);
memcpy(tensors.inp_pos->data, ubatch.pos, buf_size);
}
embd_buf = ubatch.backend_embd;
}
GGML_ASSERT(embd_buf != nullptr);
// copy device data to cpu memory
size_t buf_size = sub_gf_out->ne[0] * sub_gf_out->ne[1] * sizeof(float);
ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(lctx.sched[i], sub_gf_out);
GGML_ASSERT(buf_size <= ggml_nbytes(sub_gf_out));
GGML_ASSERT(backend != nullptr);
ggml_backend_tensor_get_async(backend, sub_gf_out, embd_buf, 0, buf_size);
// send the result to the next node or the master
if (!(n_world == 1 || (my_rank == 0 && is_last_l))) {
struct input_tensors tensors = {sub_gf_out, lctx.inp_pos};
const bool is_to_master = my_rank != 0 && is_last_l;
zmq::socket_t * s = is_to_master ? lctx.master_socket : lctx.send_socket;
llama_send_tensors(*s, &tensors);
ggml_backend_sched_synchronize(lctx.sched[i]);
llama_send_tensors(*s, &ubatch, &tensors);
}
// overlap memory scheduling with other nodes' communication and computing