#include "arg.h"
#include "common.h"
#include "log.h"
#include "llama-cpp.h"

#include <clocale>
#include <cstdint>
#include <string>
#include <vector>

// Owning wrapper around a llama_batch: the batch is created with
// llama_batch_init() in the constructor and released with llama_batch_free()
// in the destructor. Copying is disabled; moving transfers ownership.
struct llama_batch_ptr {
    llama_batch batch;

    llama_batch_ptr(int32_t n_tokens, int32_t embd, int32_t n_seq_max)
        : batch{llama_batch_init(n_tokens, embd, n_seq_max)} {}

    ~llama_batch_ptr() { llama_batch_free(batch); }

    llama_batch_ptr(const llama_batch_ptr &) = delete;
    llama_batch_ptr & operator=(const llama_batch_ptr &) = delete;

    // A defaulted move would leave both objects holding the same raw buffers,
    // and both destructors would free them. Zero out the source instead so
    // only the destination releases the buffers.
    // NOTE(review): this relies on llama_batch_free() skipping null members of
    // a zero-initialized batch - confirm against the llama.cpp implementation.
    llama_batch_ptr(llama_batch_ptr && other) noexcept : batch(other.batch) {
        other.batch = llama_batch{};
    }

    llama_batch_ptr & operator=(llama_batch_ptr && other) noexcept {
        if (this != &other) {
            llama_batch_free(batch);
            batch       = other.batch;
            other.batch = llama_batch{};
        }
        return *this;
    }

    llama_batch &       get() { return batch; }
    const llama_batch & get() const { return batch; }
};

// Sample and decode n_predict tokens, one at a time, on sequence seq_id.
//
// Each iteration samples a token from the last logits (index -1), logs and
// appends its text piece, then decodes it at position n_past. n_past is
// incremented once per successfully decoded token.
//
// Returns the concatenated text of all sampled tokens, or an empty string if
// llama_decode() reports a failure. Callers in this file treat an empty
// result as an error.
static std::string generate_tokens(llama_context * ctx, llama_sampler * smpl, int & n_past, int32_t n_predict, llama_seq_id seq_id) {
    std::string result;

    // single-token batch, reused for every step
    llama_batch_ptr batch(1, 0, 1);

    for (int i = 0; i < n_predict; i++) {
        const auto next_token     = llama_sampler_sample(smpl, ctx, -1);
        const auto next_token_str = common_token_to_piece(ctx, next_token);

        LOG("%s", next_token_str.c_str());
        result += next_token_str;

        common_batch_clear(batch.get());
        common_batch_add(batch.get(), next_token, n_past, {seq_id}, true);

        if (llama_decode(ctx, batch.get())) {
            LOG_ERR("\n%s: failed to evaluate\n", __func__);
            return {};
        }
        n_past++;
    }

    return result;
}

// Test 1: baseline
//  - tokenize the prompt
//  - decode all but the last token
//  - save state to disk
//  - decode the last token
//  - generate n_predict tokens
//
// The prompt decoding and state saving are delegated to
// common_prompt_batch_decode(), which is given params.out_file.
// Returns the generated text, or an empty string on failure.
static std::string test_baseline(struct llama_model * model, const struct common_params & params) {
    auto ctx = llama_context_ptr{llama_init_from_model(model, common_context_params_to_llama(params))};

    auto sparams = llama_sampler_chain_default_params();
    auto smpl    = llama_sampler_ptr{llama_sampler_chain_init(sparams)};
    llama_sampler_chain_add(smpl.get(), llama_sampler_init_dist(params.sampling.seed));

    auto tokens = common_tokenize(ctx.get(), params.prompt, true);
    int  n_past = 0;

    if (!common_prompt_batch_decode(ctx.get(), tokens, n_past, params.n_batch, params.out_file, true)) {
        LOG_ERR("%s: failed to decode prompt\n", __func__);
        return {};
    }

    LOG("\n=== Test 1: baseline ===\n");
    LOG("%s", params.prompt.c_str());

    auto result = generate_tokens(ctx.get(), smpl.get(), n_past, params.n_predict, 0);
    if (result.empty()) {
        return {};
    }

    LOG("\n");
    return result;
}

// Test 2: state load
//  - create a new context
//  - load state from file
//  - replay the last prompt token
//  - generate n_predict tokens and compare against expected result
//
// Returns true only if the generated text equals expected_result.
static bool test_state_load(struct llama_model * model, const struct common_params & params, const std::string & expected_result) {
    auto ctx = llama_context_ptr{llama_init_from_model(model, common_context_params_to_llama(params))};

    auto sparams = llama_sampler_chain_default_params();
    auto smpl    = llama_sampler_ptr{llama_sampler_chain_init(sparams)};
    llama_sampler_chain_add(smpl.get(), llama_sampler_init_dist(params.sampling.seed));

    auto tokens = common_tokenize(ctx.get(), params.prompt, true);

    LOG("\n=== Test 2: state load ===\n");
    LOG("%s", params.prompt.c_str());

    // tokens.back() is used below; calling it on an empty vector is undefined
    if (tokens.empty()) {
        LOG_ERR("\n%s: prompt produced no tokens\n", __func__);
        return false;
    }

    // Load state from file
    std::vector<llama_token> unused_sts(tokens.size());
    size_t n_token_count_out = 0;

    if (!llama_state_load_file(ctx.get(), params.out_file.c_str(), unused_sts.data(), unused_sts.size(), &n_token_count_out)) {
        LOG_ERR("\n%s: failed to load state\n", __func__);
        return false;
    }
    LOG_TRC("%s: loaded state with %zu tokens\n", __func__, n_token_count_out);

    // Replay last token
    int n_past = static_cast<int>(n_token_count_out);
    if (!common_replay_last_token(ctx.get(), tokens.back(), n_past)) {
        return false;
    }
    n_past++;

    // Generate tokens
    auto result = generate_tokens(ctx.get(), smpl.get(), n_past, params.n_predict, 0);
    if (result.empty()) {
        return false;
    }

    if (result != expected_result) {
        LOG_ERR("\n%s: error: generation differs from expected\n", __func__);
        return false;
    }

    LOG("\nPASS\n");
    return true;
}

// Test 3: seq copy (host)
//  - create a multi-seq context
//  - load state from file
//  - replay the last prompt token
//  - migrate KV cache from seq 0 to seq 1 via the CPU path
//  - generate n_predict tokens on seq 1 and compare against expected result
//
// Returns true only if the generated text equals expected_result.
static bool test_seq_cp_host(struct llama_model * model, const struct common_params & params, const std::string & expected_result) {
    auto params_ctx = common_context_params_to_llama(params);
    params_ctx.n_seq_max = 2;

    auto ctx = llama_context_ptr{llama_init_from_model(model, params_ctx)};

    auto sparams = llama_sampler_chain_default_params();
    auto smpl    = llama_sampler_ptr{llama_sampler_chain_init(sparams)};
    llama_sampler_chain_add(smpl.get(), llama_sampler_init_dist(params.sampling.seed));

    auto tokens = common_tokenize(ctx.get(), params.prompt, true);

    LOG("\n=== Test 3: seq copy (host) ===\n");
    LOG("%s", params.prompt.c_str());

    // tokens.back() is used below; calling it on an empty vector is undefined
    if (tokens.empty()) {
        LOG_ERR("\n%s: prompt produced no tokens\n", __func__);
        return false;
    }

    // Load state from file
    std::vector<llama_token> unused_sts(tokens.size());
    size_t n_token_count_out = 0;

    if (!llama_state_load_file(ctx.get(), params.out_file.c_str(), unused_sts.data(), unused_sts.size(), &n_token_count_out)) {
        LOG_ERR("\n%s: failed to load state\n", __func__);
        return false;
    }
    LOG_TRC("%s: loaded state with %zu tokens\n", __func__, n_token_count_out);

    // Replay last token
    int n_past = static_cast<int>(n_token_count_out);
    if (!common_replay_last_token(ctx.get(), tokens.back(), n_past)) {
        return false;
    }
    n_past++;

    // Migrate KV cache from seq 0 to seq 1 (CPU path)
    {
        std::vector<uint8_t> seq_store(llama_state_seq_get_size(ctx.get(), 0));

        const size_t ncopy = llama_state_seq_get_data(ctx.get(), seq_store.data(), seq_store.size(), 0);
        if (ncopy != seq_store.size()) {
            LOG_ERR("\n%s: seq copy data length %zu does not match expected length %zu\n", __func__, ncopy, seq_store.size());
            return false;
        }
        LOG_TRC("%s: seq 0 copied, %zu bytes\n", __func__, ncopy);

        // wipe the context memory so that seq 1 can only come from seq_store
        llama_memory_clear(llama_get_memory(ctx.get()), true);
        LOG_TRC("%s: kv cache cleared\n", __func__);

        const size_t nset = llama_state_seq_set_data(ctx.get(), seq_store.data(), seq_store.size(), 1);
        if (nset != seq_store.size()) {
            LOG_ERR("\n%s: seq set data length %zu does not match expected length %zu\n", __func__, nset, seq_store.size());
            return false;
        }
        LOG_TRC("%s: seq 1 restored, %zu bytes\n", __func__, nset);
    }

    // Generate tokens on seq 1
    auto result = generate_tokens(ctx.get(), smpl.get(), n_past, params.n_predict, 1);
    if (result.empty()) {
        return false;
    }

    if (result != expected_result) {
        LOG_ERR("\n%s: error: generation differs from expected\n", __func__);
        return false;
    }

    LOG("\nPASS\n");
    return true;
}

// Test 4: seq copy (device)
//  - create a multi-seq context
//  - load state from file
//  - replay the last prompt token
//  - migrate KV cache from seq 0 to seq 1 via the on-device path
//  - generate n_predict tokens on seq 1 and compare against expected result
//
// Same flow as test 3, but uses the *_ext state functions with
// LLAMA_STATE_SEQ_FLAGS_ON_DEVICE.
// Returns true only if the generated text equals expected_result.
static bool test_seq_cp_device(struct llama_model * model, const struct common_params & params, const std::string & expected_result) {
    auto params_ctx = common_context_params_to_llama(params);
    params_ctx.n_seq_max = 2;

    auto ctx = llama_context_ptr{llama_init_from_model(model, params_ctx)};

    auto sparams = llama_sampler_chain_default_params();
    auto smpl    = llama_sampler_ptr{llama_sampler_chain_init(sparams)};
    llama_sampler_chain_add(smpl.get(), llama_sampler_init_dist(params.sampling.seed));

    auto tokens = common_tokenize(ctx.get(), params.prompt, true);

    LOG("\n=== Test 4: seq copy (device) ===\n");
    LOG("%s", params.prompt.c_str());

    // tokens.back() is used below; calling it on an empty vector is undefined
    if (tokens.empty()) {
        LOG_ERR("\n%s: prompt produced no tokens\n", __func__);
        return false;
    }

    // Load state from file
    std::vector<llama_token> unused_sts(tokens.size());
    size_t n_token_count_out = 0;

    if (!llama_state_load_file(ctx.get(), params.out_file.c_str(), unused_sts.data(), unused_sts.size(), &n_token_count_out)) {
        LOG_ERR("\n%s: failed to load state\n", __func__);
        return false;
    }
    LOG_TRC("%s: loaded state with %zu tokens\n", __func__, n_token_count_out);

    // Replay last token
    int n_past = static_cast<int>(n_token_count_out);
    if (!common_replay_last_token(ctx.get(), tokens.back(), n_past)) {
        return false;
    }
    n_past++;

    // Migrate KV cache from seq 0 to seq 1 (on-device path)
    {
        std::vector<uint8_t> seq_store(llama_state_seq_get_size_ext(ctx.get(), 0, LLAMA_STATE_SEQ_FLAGS_ON_DEVICE));

        const size_t ncopy = llama_state_seq_get_data_ext(ctx.get(), seq_store.data(), seq_store.size(), 0, LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
        if (ncopy != seq_store.size()) {
            LOG_ERR("\n%s: seq copy data length %zu does not match expected length %zu\n", __func__, ncopy, seq_store.size());
            return false;
        }
        LOG_TRC("%s: seq 0 copied, %zu bytes\n", __func__, ncopy);

        // wipe the context memory so that seq 1 can only come from seq_store
        llama_memory_clear(llama_get_memory(ctx.get()), true);
        LOG_TRC("%s: kv cache cleared\n", __func__);

        const size_t nset = llama_state_seq_set_data_ext(ctx.get(), seq_store.data(), seq_store.size(), 1, LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
        if (nset != seq_store.size()) {
            LOG_ERR("\n%s: seq set data length %zu does not match expected length %zu\n", __func__, nset, seq_store.size());
            return false;
        }
        LOG_TRC("%s: seq 1 restored, %zu bytes\n", __func__, nset);
    }

    // Generate tokens on seq 1
    auto result = generate_tokens(ctx.get(), smpl.get(), n_past, params.n_predict, 1);
    if (result.empty()) {
        return false;
    }

    if (result != expected_result) {
        LOG_ERR("\n%s: error: generation differs from expected\n", __func__);
        return false;
    }

    LOG("\nPASS\n");
    return true;
}

// Entry point: parses the command line, loads the model, then runs the four
// tests in order. Tests 2-4 compare their output against the text produced by
// test 1. Returns 0 if every test passes and 1 on the first failure.
int main(int argc, char ** argv) {
    std::setlocale(LC_NUMERIC, "C");

    common_params params;

    // defaults, set before parsing the command line
    params.prompt        = "The quick brown fox";
    params.out_file      = "dump_state.bin";
    params.sampling.seed = 1234;

    common_init();

    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) {
        return 1;
    }

    if (params.n_parallel == 1) {
        LOG_TRC("%s: n_parallel == 1, enabling unified kv cache\n", __func__);
        params.kv_unified = true;
    }

    if (params.n_predict < 0) {
        params.n_predict = 16;
    }

    ggml_backend_load_all();

    auto llama_init = common_init_from_params(params, true);

    auto * model = llama_init->model();
    if (model == nullptr) {
        LOG_ERR("%s: failed to init\n", __func__);
        return 1;
    }

    // each test creates its own context, so none is expected here
    GGML_ASSERT(llama_init->context() == nullptr);

    // Test 1: baseline (saves state to disk)
    auto result_baseline = test_baseline(model, params);
    if (result_baseline.empty()) {
        return 1;
    }

    // Test 2: state load
    if (!test_state_load(model, params, result_baseline)) {
        return 1;
    }

    // Test 3: seq copy (host)
    if (!test_seq_cp_host(model, params, result_baseline)) {
        return 1;
    }

    // Test 4: seq copy (device)
    if (!test_seq_cp_device(model, params, result_baseline)) {
        return 1;
    }

    LOG("\nAll tests passed.\n");
    return 0;
}