Modify the perplexity test to a distributed version

This commit is contained in:
leeetao 2025-06-18 07:05:53 +00:00
parent 32e1088162
commit 2123879cfe
4 changed files with 314 additions and 149 deletions

View file

@ -48,6 +48,57 @@
#define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ
#define LLAMA_STATE_SEQ_VERSION 2
typedef int32_t llama_pos;
typedef int32_t llama_seq_id;
struct sync_meta {
// batch info
int32_t n_tokens = 0;
int8_t * logits = nullptr;
llama_pos * pos = nullptr;
llama_pos all_pos_0;
llama_pos all_pos_1;
uint32_t n_ctx = 0;
int chunk_start_pos;
int32_t n_outputs; // Used to pass the number of logits to be outputted
// signal to clear the kv cache
bool clear_kv_cache= false;
// signal to remove a kv cache sequence
bool kv_seq_rm = false;
llama_seq_id rm_seq_id = 0;
llama_pos rm_p0 = 0;
llama_pos rm_p1 = 0;
// signal to add a kv cache sequence
bool kv_seq_add = false;
llama_seq_id add_seq_id = 0;
llama_pos add_p0 = 0;
llama_pos add_p1 = 0;
llama_pos add_delta = 0;
// signal to copy a kv cache sequence
bool kv_seq_cp = false;
llama_seq_id cp_src_seq_id = 0;
llama_seq_id cp_dst_seq_id = 0;
llama_pos cp_p0 = 0;
llama_pos cp_p1 = 0;
// signal to divide the kv cache range
bool kv_seq_div = false;
llama_seq_id div_seq_id = 0;
llama_pos div_p0 = 0;
llama_pos div_p1 = 0;
int div_factor = 1;
// signal to transfer tokens_size
size_t tokens_size = 0;
};
#ifdef __cplusplus
extern "C" {
#endif
@ -451,6 +502,8 @@ extern "C" {
LLAMA_API void llama_init_sockets (struct llama_context * ctx, uint32_t n_world, uint32_t my_rank);
LLAMA_API void llama_free_sockets (struct llama_context * ctx, char ** msg);
LLAMA_API int llama_recv_meta (struct llama_context * ctx, struct sync_meta * meta);
LLAMA_API void llama_send_meta (struct llama_context * ctx, struct sync_meta * meta);
LLAMA_API int llama_gather_device_info(struct llama_context * ctx, struct device_info * dev_info_set);
LLAMA_API int llama_send_device_info (struct llama_context * ctx, struct device_info * dev_info);
LLAMA_API int llama_bcast_startup_args(struct llama_context * ctx, uint32_t rank, struct startup_args * args);