Refactored the logic related to communication content and timing control

This commit is contained in:
leeetao 2025-06-24 10:40:37 +00:00
parent 4b823775ec
commit a3becb586a
5 changed files with 474 additions and 134 deletions

View file

@ -58,12 +58,15 @@ struct sync_meta {
int8_t * logits = nullptr;
llama_pos * pos = nullptr;
int32_t * n_seq_id = nullptr;
llama_seq_id ** seq_id = nullptr;
llama_pos all_pos_0;
llama_pos all_pos_1;
uint32_t n_ctx = 0;
int chunk_start_pos;
int32_t n_outputs; // Used to pass the number of logits to be outputted
// used for perplexity evaluation
int32_t n_outputs;
bool chunk_done = false; // signal that the chunk is done
// signal to clear the kv cache
bool clear_kv_cache= false;
@ -389,6 +392,7 @@ extern "C" {
int32_t n_threads; // number of threads to use for generation
int32_t n_threads_batch; // number of threads to use for batch processing
enum llama_rope_scaling_type rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
enum llama_pooling_type pooling_type; // whether to pool (sum) embedding results by sequence id
enum llama_attention_type attention_type; // attention type to use for embeddings
@ -422,6 +426,7 @@ extern "C" {
// currently works only with CPU execution
ggml_abort_callback abort_callback;
void * abort_callback_data;
bool is_perplexity_eval; // whether to run in perplexity evaluation mode
};
// model quantization parameters
@ -502,8 +507,8 @@ extern "C" {
LLAMA_API void llama_init_sockets (struct llama_context * ctx, uint32_t n_world, uint32_t my_rank);
LLAMA_API void llama_free_sockets (struct llama_context * ctx, char ** msg);
LLAMA_API int llama_recv_meta (struct llama_context * ctx, struct sync_meta * meta);
LLAMA_API void llama_send_meta (struct llama_context * ctx, struct sync_meta * meta);
LLAMA_API int llama_recv_meta (struct llama_context * ctx, struct sync_meta * meta, bool reverse);
LLAMA_API void llama_send_meta (struct llama_context * ctx, struct sync_meta * meta, bool reverse);
LLAMA_API int llama_gather_device_info(struct llama_context * ctx, struct device_info * dev_info_set);
LLAMA_API int llama_send_device_info (struct llama_context * ctx, struct device_info * dev_info);
LLAMA_API int llama_bcast_startup_args(struct llama_context * ctx, uint32_t rank, struct startup_args * args);