diff --git a/common/common.h b/common/common.h
index 0a679213..71c6e435 100644
--- a/common/common.h
+++ b/common/common.h
@@ -547,9 +547,9 @@ llama_control_vector_data llama_control_vector_load(const std::vector<llama_control_vector_load_info> & load_infos);
 #include
diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp
--- a/examples/perplexity/perplexity.cpp
+++ b/examples/perplexity/perplexity.cpp
@@ -473,6 +473,10 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
 }
 
 static results_perplexity perplexity(llama_context * ctx, const gpt_params & params, const int32_t n_ctx) {
+    uint32_t my_rank = params.rank;
+    uint32_t n_world = params.n_world;
+    bool is_last_dev = (my_rank == n_world - 1);
+
     if (params.ppl_stride > 0) {
         return perplexity_v2(ctx, params);
     }
@@ -485,38 +489,74 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
     const bool add_bos = llama_add_bos_token(llama_get_model(ctx));
     GGML_ASSERT(!llama_add_eos_token(llama_get_model(ctx)));
 
+    // only the last device stores the logits file
     std::ofstream logits_stream;
-    if (!params.logits_file.empty()) {
-        logits_stream.open(params.logits_file.c_str(), std::ios::binary);
-        if (!logits_stream.is_open()) {
-            LOG_ERR("%s: failed to open %s for writing\n", __func__, params.logits_file.c_str());
-            return {};
+    if (my_rank == n_world - 1) {
+        if (!params.logits_file.empty()) {
+            logits_stream.open(params.logits_file.c_str(), std::ios::binary);
+            if (!logits_stream.is_open()) {
+                LOG_ERR("%s: failed to open %s for writing\n", __func__, params.logits_file.c_str());
+                return {};
+            }
+            LOG_INF("%s: saving all logits to %s\n", __func__, params.logits_file.c_str());
+            logits_stream.write("_logits_", 8);
+            logits_stream.write(reinterpret_cast<const char *>(&n_ctx), sizeof(n_ctx));
         }
-        LOG_INF("%s: saving all logits to %s\n", __func__, params.logits_file.c_str());
-        logits_stream.write("_logits_", 8);
-        logits_stream.write(reinterpret_cast<const char *>(&n_ctx), sizeof(n_ctx));
     }
 
-    auto tim1 = std::chrono::high_resolution_clock::now();
-    LOG_INF("%s: tokenizing the input ..\n", __func__);
+    std::vector<llama_token> tokens;
+    size_t tokens_size = 0;
 
-    std::vector<llama_token> tokens = ::llama_tokenize(ctx, params.prompt, true);
+    // maybe we need to try other solutions, such as direct communication of tokens between the head and tail nodes
+    if (my_rank == 0 || is_last_dev) {
+        auto tim1 = std::chrono::high_resolution_clock::now();
+        LOG_INF("%s: rank %d tokenizing the input ..\n", __func__, my_rank);
 
-    auto tim2 = std::chrono::high_resolution_clock::now();
-    LOG_INF("%s: tokenization took %g ms\n",__func__,1e-3*std::chrono::duration_cast<std::chrono::microseconds>(tim2-tim1).count());
+        tokens = ::llama_tokenize(ctx, params.prompt, true);
+        tokens_size = tokens.size();
 
-    if (int(tokens.size()) < 2*n_ctx) {
-        LOG_ERR("%s: you need at least %d tokens to evaluate perplexity with a context of %d\n",__func__,2*n_ctx,
+        auto tim2 = std::chrono::high_resolution_clock::now();
+        LOG_INF("%s: rank %d tokenization took %g ms\n", __func__, my_rank, 1e-3*std::chrono::duration_cast<std::chrono::microseconds>(tim2-tim1).count());
+    }
+
+    {
+        if (n_world > 1) {
+            sync_meta meta;
+            if (my_rank == 0) {
+                meta.tokens_size = tokens_size;
+                llama_send_meta(ctx, &meta);
+            } else {
+                if (llama_recv_meta(ctx, &meta) == -1) {
+                    LOG_ERR("%s: failed to receive tokens_size on rank %d\n", __func__, my_rank);
+                    return { {}, -1.0, {}, {} };
+                }
+                if (is_last_dev) {
+                    GGML_ASSERT(tokens_size == meta.tokens_size && "Token size mismatch between rank 0 and last rank!");
+                } else {
+                    tokens_size = meta.tokens_size;
+                    llama_send_meta(ctx, &meta);
+                }
+            }
+        }
+    }
+    LOG_INF("%s: rank %d synchronized tokens_size = %zu\n", __func__, my_rank, tokens_size);
+
+    if (my_rank == 0) {
+        if (int(tokens.size()) < 2*n_ctx) {
+            LOG_ERR("%s: you need at least %d tokens to evaluate perplexity with a context of %d\n",__func__,2*n_ctx,
                 n_ctx);
-        LOG_ERR("%s: the data file you provided tokenizes to only %zu tokens\n",__func__,tokens.size());
-        return {std::move(tokens), 0., {}, {}};
+            LOG_ERR("%s: the data file you provided tokenizes to only %zu tokens\n",__func__,tokens.size());
+            return {std::move(tokens), 0., {}, {}};
+        }
     }
 
     std::vector<float> logit_history;
-    logit_history.resize(tokens.size());
-
     std::vector<float> prob_history;
-    prob_history.resize(tokens.size());
+
+    if (is_last_dev) {
+        logit_history.resize(tokens_size);
+        prob_history.resize(tokens_size);
+    }
 
     const int n_chunk_max = tokens.size() / n_ctx;
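Note: the hunk above establishes the relay pattern this patch uses for every piece of shared state: rank 0 originates a sync_meta message, each middle rank receives it and forwards it down the chain, and the last rank consumes it without re-sending. A condensed sketch of that pattern, using only the llama_send_meta/llama_recv_meta API added in include/llama.h below; relay_meta is an illustrative helper name, not part of the patch:

    // Illustrative only: the chain relay used for tokens_size here and for the
    // clear_kv_cache and batch-meta signals further down.
    static int relay_meta(llama_context * ctx, uint32_t my_rank, uint32_t n_world,
                          sync_meta * meta) {
        if (n_world == 1) {
            return 0;                       // single device: nothing to sync
        }
        if (my_rank == 0) {
            llama_send_meta(ctx, meta);     // originate at the head
        } else {
            if (llama_recv_meta(ctx, meta) == -1) {
                return -1;                  // receive timed out
            }
            if (my_rank != n_world - 1) {
                llama_send_meta(ctx, meta); // middle ranks forward it
            }
        }
        return 0;
    }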
LOG_ERR("%s: you need at least %d tokens to evaluate perplexity with a context of %d\n",__func__,2*n_ctx, n_ctx); - LOG_ERR("%s: the data file you provided tokenizes to only %zu tokens\n",__func__,tokens.size()); - return {std::move(tokens), 0., {}, {}}; + LOG_ERR("%s: the data file you provided tokenizes to only %zu tokens\n",__func__,tokens.size()); + return {std::move(tokens), 0., {}, {}}; + } } std::vector logit_history; - logit_history.resize(tokens.size()); - std::vector prob_history; - prob_history.resize(tokens.size()); + + if (is_last_dev) { + logit_history.resize(tokens_size); + prob_history.resize(tokens_size); + } const int n_chunk_max = tokens.size() / n_ctx; @@ -537,23 +577,34 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par llama_batch batch = llama_batch_init(std::min(n_batch, n_ctx*n_seq), 0, 1); std::vector logits; - if (num_batches > 1) { - logits.reserve((size_t)n_ctx * n_vocab); + + if(is_last_dev){ + if (num_batches > 1) { + logits.reserve((size_t)n_ctx * n_vocab); + } } LOG_INF("%s: calculating perplexity over %d chunks, n_ctx=%d, batch_size=%d, n_seq=%d\n", __func__, n_chunk, n_ctx, n_batch, n_seq); std::vector workers(std::thread::hardware_concurrency() - 1); - std::vector log_probs; - if (!params.logits_file.empty()) { - logits_stream.write((const char *)&n_vocab, sizeof(n_vocab)); - logits_stream.write((const char *)&n_chunk, sizeof(n_chunk)); - logits_stream.write((const char *)tokens.data(), n_chunk*n_ctx*sizeof(tokens[0])); + std::vector log_probs; // the log probabilities of logits + + // rank 0 and last rank store log_probs + // only rank 0 or last device stores logits/log_probs + if (!params.logits_file.empty() && (is_last_dev || my_rank == 0)) { const int nv = 2*((n_vocab + 1)/2) + 4; log_probs.resize(n_ctx * nv); - } + // additional operations only for rank 0 or single device + if (my_rank == 0) { + // For single device, is_last_dev and my_rank==0 are both true + // For multiple devices, only rank 0 will write these headers + logits_stream.write((const char *)&n_vocab, sizeof(n_vocab)); + logits_stream.write((const char *)&n_chunk, sizeof(n_chunk)); + logits_stream.write((const char *)tokens.data(), n_chunk*n_ctx*sizeof(tokens[0])); + } + } // We get the logits for all the tokens in the context window (params.n_ctx) // from llama_eval above. 
Now, based on https://huggingface.co/docs/transformers/perplexity, // calculate the perplexity over the last half of the window (so the model always has @@ -576,43 +627,98 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par const auto t_start = std::chrono::high_resolution_clock::now(); + { + if (n_world > 1) { + sync_meta clear_meta; + clear_meta.clear_kv_cache = true; + + if (my_rank == 0) { + llama_send_meta(ctx, &clear_meta); + } else { + if (llama_recv_meta(ctx, &clear_meta) == -1) { + LOG_ERR("Failed to recv clear_kv_cache signal on rank %d\n", my_rank); + return {tokens, -1.0, {}, {}}; + } + if (!is_last_dev) { + llama_send_meta(ctx, &clear_meta); + } + } + } + } // clear the KV cache llama_kv_cache_clear(ctx); + sync_meta meta; + for (int j = 0; j < num_batches; ++j) { const int batch_start = start + j * n_batch; const int batch_size = std::min(end - batch_start, n_batch); + + int n_outputs = 0; - int n_outputs = 0; + // only rank 0 constructs the batch, other ranks just receive it + if (my_rank == 0){ - batch.n_tokens = 0; + batch.n_tokens = 0; - for (int seq = 0; seq < n_seq_batch; seq++) { - int seq_start = batch_start + seq*n_ctx; + for (int seq = 0; seq < n_seq_batch; seq++) { + int seq_start = batch_start + seq*n_ctx; - // save original token and restore it after eval - const auto token_org = tokens[seq_start]; + // save original token and restore it after eval + const auto token_org = tokens[seq_start]; - // add BOS token for the first batch of each chunk - if (add_bos && j == 0) { - tokens[seq_start] = llama_token_bos(llama_get_model(ctx)); + // add BOS token for the first batch of each chunk + if (add_bos && j == 0) { + tokens[seq_start] = llama_token_bos(llama_get_model(ctx)); + } + + for (int k = 0; k < batch_size; ++k) { + const int idx = seq*n_ctx + k; + + batch.token [idx] = tokens[seq_start + k]; + batch.pos [idx] = j*n_batch + k; + batch.n_seq_id[idx] = 1; + batch.seq_id [idx][0] = seq; + batch.logits [idx] = batch.pos[idx] >= first ? 1 : 0; + + n_outputs += batch.logits[idx] != 0; + + } + batch.n_tokens += batch_size; + + // restore the original token in case it was set to BOS + tokens[seq_start] = token_org; } + } - for (int k = 0; k < batch_size; ++k) { - const int idx = seq*n_ctx + k; + // other ranks need to know batch info + { + if (n_world > 1) { + meta.n_ctx = n_ctx; - batch.token [idx] = tokens[seq_start + k]; - batch.pos [idx] = j*n_batch + k; - batch.n_seq_id[idx] = 1; - batch.seq_id [idx][0] = seq; - batch.logits [idx] = batch.pos[idx] >= first ? 
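Note: only rank 0 computes n_outputs while building the batch, which is why it must travel in meta; the last device later sizes its logits reads from it. For a given chunk the value is fully determined by the flags set above. An illustrative helper, assuming first = n_ctx/2 as elsewhere in this function:

    // Illustrative only: logits rows produced per sequence of one chunk; the
    // positions flagged with batch.logits[idx] = 1 are exactly pos >= first.
    static int expected_outputs(int n_ctx, int first) {
        int n_outputs = 0;
        for (int pos = 0; pos < n_ctx; ++pos) {
            n_outputs += (pos >= first) ? 1 : 0;
        }
        return n_outputs; // == n_ctx - first, e.g. 256 when n_ctx = 512, first = 256
    }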
@@ -620,14 +726,16 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
                 return {tokens, -1, logit_history, prob_history};
             }
 
-            if (num_batches > 1 && n_outputs > 0) {
-                const auto * batch_logits = llama_get_logits(ctx);
-                logits.insert(logits.end(), batch_logits, batch_logits + n_outputs * n_vocab);
+            if (is_last_dev && num_batches > 1) {
+                const int n_outputs_synced = meta.n_outputs;
+                if (n_outputs_synced > 0) {
+                    const auto * batch_logits = llama_get_logits(ctx);
+                    logits.insert(logits.end(), batch_logits, batch_logits + n_outputs_synced * n_vocab);
+                }
             }
         }
 
-        if (i == 0) {
+        if (my_rank == 0 && i == 0) {
             llama_synchronize(ctx);
             const auto t_end = std::chrono::high_resolution_clock::now();
             const float t_total = std::chrono::duration<float>(t_end - t_start).count();
@@ -640,53 +748,60 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
             LOG("%.2f minutes\n", total_seconds / 60.0);
         }
 
-        for (int seq = 0; seq < n_seq_batch; seq++) {
-            const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits_ith(ctx, seq*n_ctx + first);
+        if (is_last_dev) {
+            for (int seq = 0; seq < n_seq_batch; seq++) {
+                const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits_ith(ctx, seq*n_ctx + first);
+
+                int chunk_start_pos = meta.chunk_start_pos;
+                llama_token * tokens_data = tokens.data() + chunk_start_pos + seq*n_ctx + first;
+                if (!params.logits_file.empty()) {
+                    process_logits(logits_stream, n_vocab, all_logits,
+                            tokens_data, n_ctx - 1 - first,
+                            workers, log_probs, nll, nll2);
+                } else {
+                    process_logits(n_vocab, all_logits,
+                            tokens_data, n_ctx - 1 - first,
+                            workers, nll, nll2,
+                            logit_history.data() + chunk_start_pos + seq*n_ctx + first,
+                            prob_history.data() + chunk_start_pos + seq*n_ctx + first);
+                }
+                count += n_ctx - first - 1;
 
-            llama_token * tokens_data = tokens.data() + start + seq*n_ctx + first;
-            if (!params.logits_file.empty()) {
-                process_logits(logits_stream, n_vocab, all_logits,
-                        tokens_data, n_ctx - 1 - first,
-                        workers, log_probs, nll, nll2);
-            } else {
-                process_logits(n_vocab, all_logits,
-                        tokens_data, n_ctx - 1 - first,
-                        workers, nll, nll2,
-                        logit_history.data() + start + seq*n_ctx + first,
-                        prob_history.data() + start + seq*n_ctx + first);
-            }
-            count += n_ctx - first - 1;
-
-            // perplexity is e^(average negative log-likelihood)
-            if (params.ppl_output_type == 0) {
-                LOG("[%d]%.4lf,", i + seq + 1, std::exp(nll / count));
-            } else {
-                double av = nll/count;
-                double av2 = nll2/count - av*av;
-                if (av2 > 0) av2 = sqrt(av2/(count-1));
-                LOG("%8d  %.4lf  %4lf  %4lf\n", i*n_ctx, std::exp(nll / count), av, av2);
+                // perplexity is e^(average negative log-likelihood)
+                if (params.ppl_output_type == 0) {
+                    LOG("[%d]%.4lf,", i + seq + 1, std::exp(nll / count));
+                } else {
+                    double av = nll/count;
+                    double av2 = nll2/count - av*av;
+                    if (av2 > 0) av2 = sqrt(av2/(count-1));
+                    LOG("%8d  %.4lf  %4lf  %4lf\n", i*n_ctx, std::exp(nll / count), av, av2);
+                }
             }
         }
-
         logits.clear();
     }
     LOG("\n");
 
-    nll2 /= count;
-    nll /= count;
-    const double ppl = exp(nll);
-    nll2 -= nll * nll;
-    if (nll2 > 0) {
-        nll2 = sqrt(nll2/(count-1));
-        LOG_INF("Final estimate: PPL = %.4lf +/- %.5lf\n", ppl, nll2*ppl);
-    } else {
-        LOG_ERR("Unexpected negative standard deviation of log(prob)\n");
+    if (is_last_dev) {
+        nll2 /= count;
+        nll /= count;
+        const double ppl = exp(nll);
+        nll2 -= nll * nll;
+
+        if (nll2 > 0) {
+            nll2 = sqrt(nll2/(count-1));
+            LOG_INF("Final estimate: PPL = %.4lf +/- %.5lf\n", ppl, nll2*ppl);
+        } else {
+            LOG_ERR("Unexpected negative standard deviation of log(prob)\n");
+        }
+
+        return {tokens, ppl, logit_history, prob_history};
     }
 
     llama_batch_free(batch);
 
-    return {tokens, ppl, logit_history, prob_history};
-}
+    return {};
+}
 
 static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<float> & batch_logits, int32_t n_batch, int32_t n_vocab) {
     int prev_outputs = 0;
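Note: the final estimate above is the standard perplexity formula. nll accumulates the negative log-likelihood of each scored token and nll2 its square; a self-contained restatement of the computation (names are illustrative):

    #include <cmath>
    #include <cstdio>

    // Illustrative only: PPL = exp(mean NLL); the +/- term is the standard
    // error of the mean NLL, propagated through exp() as stderr * ppl.
    static void print_ppl(double nll, double nll2, int count) {
        const double mean = nll / count;           // average -log p(token)
        const double ppl  = std::exp(mean);
        double var = nll2 / count - mean * mean;   // per-token NLL variance
        if (var > 0) {
            const double se = std::sqrt(var / (count - 1));
            std::printf("PPL = %.4lf +/- %.5lf\n", ppl, se * ppl);
        }
    }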
@@ -1967,6 +2082,23 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
+    uint32_t n_world = params.n_world;
+    uint32_t my_rank = params.rank;
+    GGML_ASSERT(!(n_world == 1 && my_rank > 0));
+
+    // check that --n-layer-window and --world match
+    if (my_rank == 0) {
+        uint32_t non_zero_count = 0;
+        size_t size = sizeof(params.n_layer_window) / sizeof(params.n_layer_window[0]);
+        for (size_t i = 0; i < size; ++i) {
+            if (params.n_layer_window[i] != 0) {
+                ++non_zero_count;
+            }
+        }
+        GGML_ASSERT((non_zero_count == 0 || non_zero_count == n_world) \
+            && "Number of non-zero values in --n-layer-window must equal --world");
+    }
+
     gpt_init();
 
     const int32_t n_ctx = params.n_ctx;
@@ -2008,6 +2140,12 @@ int main(int argc, char ** argv) {
     // load the model and apply lora adapter, if any
     llama_init_result llama_init = llama_init_from_gpt_params(params);
 
+    // update the rank and world size in case any devices were removed
+    my_rank = params.rank;
+    n_world = params.n_world;
+
+    bool is_last_dev = (my_rank == n_world - 1);
+
     llama_model * model = llama_init.model;
     llama_context * ctx = llama_init.context;
     if (model == NULL) {
@@ -2028,6 +2166,13 @@ int main(int argc, char ** argv) {
         LOG_INF("%s\n", gpt_params_get_system_info(params).c_str());
     }
 
+    char * stop_signal = nullptr;
+    std::thread signal_thread;
+
+    if (my_rank != 0) {
+        signal_thread = std::thread(llama_free_sockets, ctx, &stop_signal);
+    }
+
     struct results_perplexity results;
     if (params.hellaswag) {
         hellaswag_score(ctx, params);
@@ -2042,9 +2187,11 @@ int main(int argc, char ** argv) {
     }
 
     LOG("\n");
-    llama_perf_context_print(ctx);
-
-    write_logfile(ctx, params, model, results);
+
+    if (is_last_dev) {
+        llama_perf_context_print(ctx);
+        write_logfile(ctx, params, model, results);
+    }
 
     llama_free(ctx);
     llama_free_model(model);
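Note on the signal thread above: worker ranks park a thread in llama_free_sockets(), which presumably blocks until the head node broadcasts a stop signal. The hunks shown never join it; a sketch of the teardown the caller would still need somewhere before exit (illustrative, not part of the patch):

    // Illustrative only: a joinable std::thread must be joined (or detached)
    // before its destructor runs, otherwise std::terminate() is called.
    if (signal_thread.joinable()) {
        signal_thread.join();
    }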
diff --git a/include/llama.h b/include/llama.h
index 8bb8ac50..c8706488 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -48,6 +48,57 @@
 #define LLAMA_STATE_SEQ_MAGIC   LLAMA_FILE_MAGIC_GGSQ
 #define LLAMA_STATE_SEQ_VERSION 2
 
+typedef int32_t llama_pos;
+typedef int32_t llama_seq_id;
+
+struct sync_meta {
+    // batch info
+    int32_t n_tokens = 0;
+
+    int8_t    * logits = nullptr;
+    llama_pos * pos    = nullptr;
+    llama_pos   all_pos_0;
+    llama_pos   all_pos_1;
+    uint32_t    n_ctx  = 0;
+
+    int     chunk_start_pos;
+    int32_t n_outputs; // used to pass the number of logits to be output
+
+    // signal to clear the kv cache
+    bool clear_kv_cache = false;
+
+    // signal to remove a kv cache sequence
+    bool kv_seq_rm = false;
+    llama_seq_id rm_seq_id = 0;
+    llama_pos    rm_p0 = 0;
+    llama_pos    rm_p1 = 0;
+
+    // signal to add a kv cache sequence
+    bool kv_seq_add = false;
+    llama_seq_id add_seq_id = 0;
+    llama_pos    add_p0 = 0;
+    llama_pos    add_p1 = 0;
+    llama_pos    add_delta = 0;
+
+    // signal to copy a kv cache sequence
+    bool kv_seq_cp = false;
+    llama_seq_id cp_src_seq_id = 0;
+    llama_seq_id cp_dst_seq_id = 0;
+    llama_pos    cp_p0 = 0;
+    llama_pos    cp_p1 = 0;
+
+    // signal to divide the kv cache range
+    bool kv_seq_div = false;
+    llama_seq_id div_seq_id = 0;
+    llama_pos    div_p0 = 0;
+    llama_pos    div_p1 = 0;
+    int div_factor = 1;
+
+    // signal to transfer tokens_size
+    size_t tokens_size = 0;
+};
+
 #ifdef __cplusplus
 extern "C" {
 #endif
@@ -451,6 +502,8 @@ extern "C" {
     LLAMA_API void llama_init_sockets      (struct llama_context * ctx, uint32_t n_world, uint32_t my_rank);
     LLAMA_API void llama_free_sockets      (struct llama_context * ctx, char ** msg);
+    LLAMA_API int  llama_recv_meta         (struct llama_context * ctx, struct sync_meta * meta);
+    LLAMA_API void llama_send_meta         (struct llama_context * ctx, struct sync_meta * meta);
     LLAMA_API int  llama_gather_device_info(struct llama_context * ctx, struct device_info * dev_info_set);
     LLAMA_API int  llama_send_device_info  (struct llama_context * ctx, struct device_info * dev_info);
     LLAMA_API int  llama_bcast_startup_args(struct llama_context * ctx, uint32_t rank, struct startup_args * args);
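Note: sync_meta doubles as a tagged command block. Every signal is a bool tag plus its arguments, and the defaults are chosen so that a freshly constructed struct encodes a no-op; the receiver presumably dispatches on whichever tag is set. For example, broadcasting a KV-cache removal from the head node might look like this (illustrative values):

    sync_meta meta;            // default-constructed: all tags false, a no-op
    meta.kv_seq_rm = true;     // select the "remove a kv cache sequence" signal
    meta.rm_seq_id = 0;        // sequence to edit
    meta.rm_p0     = 0;        // remove positions in [p0, p1)
    meta.rm_p1     = 128;
    llama_send_meta(ctx, &meta);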
diff --git a/src/llama.cpp b/src/llama.cpp
index 321d9aaf..9428e3bd 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -17866,46 +17866,13 @@ struct input_tensors {
     ggml_tensor * inp_pos;
 };
 
-struct sync_meta {
-    int32_t n_tokens = 0;
-    llama_pos * pos = nullptr;
-    llama_pos all_pos_0;
-    llama_pos all_pos_1;
-    uint32_t n_ctx = 0;
-
-    // signal to clear the kv cache
-    bool clear_kv_cache = false;
-
-    // signal to remove a kv cache sequence
-    bool kv_seq_rm = false;
-    llama_seq_id rm_seq_id = 0;
-    llama_pos rm_p0 = 0;
-    llama_pos rm_p1 = 0;
-
-    // signal to add a kv cache sequence
-    bool kv_seq_add = false;
-    llama_seq_id add_seq_id = 0;
-    llama_pos add_p0 = 0;
-    llama_pos add_p1 = 0;
-    llama_pos add_delta = 0;
-
-    // signal to copy a kv cache sequence
-    bool kv_seq_cp = false;
-    llama_seq_id cp_src_seq_id = 0;
-    llama_seq_id cp_dst_seq_id = 0;
-    llama_pos cp_p0 = 0;
-    llama_pos cp_p1 = 0;
-
-    // signal to divide the kv cache range
-    bool kv_seq_div = false;
-    llama_seq_id div_seq_id = 0;
-    llama_pos div_p0 = 0;
-    llama_pos div_p1 = 0;
-    int div_factor = 1;
-};
-
-static void llama_send_meta(zmq::socket_t & socket, struct sync_meta * meta) {
+void llama_send_meta(llama_context * ctx, struct sync_meta * meta) {
+    GGML_ASSERT(ctx  != nullptr);
     GGML_ASSERT(meta != nullptr);
+
+    zmq::socket_t * send_socket = ctx->send_socket;
+    GGML_ASSERT(send_socket != nullptr);
+
     try {
         std::vector<zmq::const_buffer> send_msgs;
@@ -17924,21 +17891,24 @@ static void llama_send_meta(zmq::socket_t & socket, struct sync_meta * meta) {
         send_msgs.emplace_back("all_pos_1", strlen("all_pos_1"));
         send_msgs.emplace_back(&(meta->all_pos_1), sizeof(meta->all_pos_1));
 
-        zmq::send_multipart(socket, send_msgs);
+        if (!send_msgs.empty()) {
+            zmq::send_multipart(*send_socket, send_msgs);
+        }
     } catch (const zmq::error_t& e) {
         LLAMA_LOG_INFO("Failed to send meta data: %s\n", e.what());
     }
 }
 
-static int llama_recv_meta(zmq::socket_t & socket, struct sync_meta * meta) {
-    socket.set(zmq::sockopt::rcvtimeo, 1000);
+int llama_recv_meta(llama_context * ctx, struct sync_meta * meta) {
+    ctx->recv_socket->set(zmq::sockopt::rcvtimeo, 1000);
 
     std::vector<zmq::message_t> recv_msgs;
-    if (!zmq::recv_multipart(socket, std::back_inserter(recv_msgs))) {
+
+    if (!zmq::recv_multipart(*(ctx->recv_socket), std::back_inserter(recv_msgs))) {
         return -1;
     }
-    socket.set(zmq::sockopt::rcvtimeo, -1);
+
+    ctx->recv_socket->set(zmq::sockopt::rcvtimeo, -1);
 
     const std::string cmd = recv_msgs[0].to_string();
     size_t idx = 1;
@@ -18210,11 +18180,6 @@ static void manage_graph_tensors(struct ggml_cgraph * cgraph, int advice, bool f
 
 static int llama_decode_internal(
          llama_context & lctx,
            llama_batch & batch_all) { // TODO: rename back to batch
-
-    // llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.kv_self.get_pos_max() + 1);
-
-    // // const llama_batch & batch_all = batch_allocr.batch;
-    // llama_batch & batch_all = batch_allocr.batch;
 
     lctx.is_encoding = false;
@@ -18277,13 +18242,13 @@ static int llama_decode_internal(
         n_outputs = 1;
     }
 
-    // TODO:needs to be encapsulated into a function
+    // prepare for send and receive of metadata
     sync_meta meta;
     meta.n_ctx = cparams.n_ctx;
 
     bool is_last_dev = (my_rank == n_world - 1);
 
     if (my_rank != 0) {
-        if (llama_recv_meta(*lctx.recv_socket, &meta) == -1) {
+        if (llama_recv_meta(&lctx, &meta) == -1) {
             return -1;
         }
@@ -18343,7 +18308,7 @@ static int llama_decode_internal(
         meta.pos       = batch_all.pos;
         meta.all_pos_0 = batch_all.all_pos_0;
         meta.all_pos_1 = batch_all.all_pos_1;
-        llama_send_meta(*lctx.send_socket, &meta);
+        llama_send_meta(&lctx, &meta);
     }
 
     lctx.sbatch.from_batch(batch_all, n_embd,
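Note on the wire format visible above: each sync_meta field travels as a pair of ZeroMQ frames, a short ASCII key followed by the raw bytes of the value, and the receiver walks the frames pairwise after reading the command frame. A condensed sketch of that framing, assuming cppzmq's send_multipart/recv_multipart helpers (which the patch already uses); the helper names are illustrative:

    #include <cstring>
    #include <zmq_addon.hpp>

    // Illustrative only: append one key/value pair to an outgoing multipart message.
    static void push_field(std::vector<zmq::const_buffer> & msgs,
                           const char * key, const void * val, size_t size) {
        msgs.emplace_back(key, strlen(key));  // frame 2k:   field name
        msgs.emplace_back(val, size);         // frame 2k+1: raw field bytes
    }

    // Illustrative only: match one key/value pair on the receiving side.
    static bool pop_field(const std::vector<zmq::message_t> & msgs, size_t & idx,
                          const char * key, void * val, size_t size) {
        if (idx + 1 >= msgs.size() || msgs[idx].to_string() != key) {
            return false; // field absent: leave idx untouched
        }
        std::memcpy(val, msgs[idx + 1].data(), size);
        idx += 2;
        return true;
    }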