mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2025-09-13 02:19:41 +00:00
not working correctly
This commit is contained in:
commit
6ce85c54d6
30 changed files with 1603 additions and 724 deletions
|
@ -1350,9 +1350,9 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
));
|
||||
add_opt(common_arg(
|
||||
{"--prio"}, "N",
|
||||
string_format("set process/thread priority : 0-normal, 1-medium, 2-high, 3-realtime (default: %d)\n", params.cpuparams.priority),
|
||||
string_format("set process/thread priority : low(-1), normal(0), medium(1), high(2), realtime(3) (default: %d)\n", params.cpuparams.priority),
|
||||
[](common_params & params, int prio) {
|
||||
if (prio < 0 || prio > 3) {
|
||||
if (prio < GGML_SCHED_PRIO_LOW || prio > GGML_SCHED_PRIO_REALTIME) {
|
||||
throw std::invalid_argument("invalid value");
|
||||
}
|
||||
params.cpuparams.priority = (enum ggml_sched_priority) prio;
|
||||
|
|
|
@ -154,9 +154,10 @@ bool common_chat_msg_parser::try_parse_reasoning(const std::string & start_think
|
|||
if (!rest.empty()) {
|
||||
handle_reasoning(rest, /* closed */ !is_partial());
|
||||
}
|
||||
if (!syntax_.thinking_forced_open) {
|
||||
throw common_chat_msg_partial_exception(end_think);
|
||||
}
|
||||
// Allow unclosed thinking tags, for now (https://github.com/ggml-org/llama.cpp/issues/13812, https://github.com/ggml-org/llama.cpp/issues/13877)
|
||||
// if (!syntax_.thinking_forced_open) {
|
||||
// throw common_chat_msg_partial_exception(end_think);
|
||||
// }
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -211,6 +211,7 @@ bool set_process_priority(enum ggml_sched_priority prio) {
|
|||
|
||||
DWORD p = NORMAL_PRIORITY_CLASS;
|
||||
switch (prio) {
|
||||
case GGML_SCHED_PRIO_LOW: p = BELOW_NORMAL_PRIORITY_CLASS; break;
|
||||
case GGML_SCHED_PRIO_NORMAL: p = NORMAL_PRIORITY_CLASS; break;
|
||||
case GGML_SCHED_PRIO_MEDIUM: p = ABOVE_NORMAL_PRIORITY_CLASS; break;
|
||||
case GGML_SCHED_PRIO_HIGH: p = HIGH_PRIORITY_CLASS; break;
|
||||
|
@ -236,6 +237,7 @@ bool set_process_priority(enum ggml_sched_priority prio) {
|
|||
|
||||
int p = 0;
|
||||
switch (prio) {
|
||||
case GGML_SCHED_PRIO_LOW: p = 5; break;
|
||||
case GGML_SCHED_PRIO_NORMAL: p = 0; break;
|
||||
case GGML_SCHED_PRIO_MEDIUM: p = -5; break;
|
||||
case GGML_SCHED_PRIO_HIGH: p = -10; break;
|
||||
|
|
1
expose.h
1
expose.h
|
@ -69,6 +69,7 @@ struct load_model_inputs
|
|||
const int quant_k = 0;
|
||||
const int quant_v = 0;
|
||||
const bool check_slowness = false;
|
||||
const bool highpriority = false;
|
||||
const bool swa_support = false;
|
||||
const float lora_multiplier = 1.0f;
|
||||
const bool quiet = false;
|
||||
|
|
|
@ -2194,6 +2194,7 @@ extern "C" {
|
|||
|
||||
// scheduling priorities
|
||||
enum ggml_sched_priority {
|
||||
GGML_SCHED_PRIO_LOW = -1,
|
||||
GGML_SCHED_PRIO_NORMAL,
|
||||
GGML_SCHED_PRIO_MEDIUM,
|
||||
GGML_SCHED_PRIO_HIGH,
|
||||
|
|
|
@ -2427,17 +2427,46 @@ static bool ggml_thread_apply_affinity(bool * mask) {
|
|||
return m != 0;
|
||||
}
|
||||
|
||||
static bool powethrottlemsgshown = false;
|
||||
static bool ggml_thread_apply_priority(int32_t prio) {
|
||||
// Note that on Windows the Process Priority Class must be updated in order to set Thread priority.
|
||||
// This is up to the applications.
|
||||
DWORD p = THREAD_PRIORITY_NORMAL;
|
||||
switch (prio) {
|
||||
case GGML_SCHED_PRIO_LOW: p = THREAD_PRIORITY_BELOW_NORMAL; break;
|
||||
case GGML_SCHED_PRIO_NORMAL: p = THREAD_PRIORITY_NORMAL; break;
|
||||
case GGML_SCHED_PRIO_MEDIUM: p = THREAD_PRIORITY_ABOVE_NORMAL; break;
|
||||
case GGML_SCHED_PRIO_HIGH: p = THREAD_PRIORITY_HIGHEST; break;
|
||||
case GGML_SCHED_PRIO_REALTIME: p = THREAD_PRIORITY_TIME_CRITICAL; break;
|
||||
}
|
||||
|
||||
#ifndef USE_FAILSAFE
|
||||
if (prio != GGML_SCHED_PRIO_LOW) {
|
||||
// Tell Windows that this thread should not be throttled (needs its own CPU core).
|
||||
// Newer Windows 11 versions aggresively park (offline) CPU cores and often place
|
||||
// all our threads onto the first 4 cores which results in terrible performance with
|
||||
// n_threads > 4
|
||||
#if _WIN32_WINNT >= 0x602
|
||||
PROCESS_POWER_THROTTLING_STATE t;
|
||||
ZeroMemory(&t, sizeof(t));
|
||||
t.Version = PROCESS_POWER_THROTTLING_CURRENT_VERSION;
|
||||
t.ControlMask = PROCESS_POWER_THROTTLING_EXECUTION_SPEED;
|
||||
t.StateMask = 0;
|
||||
|
||||
if (!SetProcessInformation(GetCurrentProcess(), ProcessPowerThrottling, &t, sizeof(t))) {
|
||||
GGML_LOG_DEBUG("failed to disable process power throttling %d : (%d)\n", prio, (int) GetLastError());
|
||||
return false;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
#else
|
||||
if(!powethrottlemsgshown)
|
||||
{
|
||||
powethrottlemsgshown = true;
|
||||
printf("\nPower Throttling skipped in compatibility mode.\n");
|
||||
}
|
||||
#endif
|
||||
|
||||
if (prio == GGML_SCHED_PRIO_NORMAL) {
|
||||
// Keep inherited policy/priority
|
||||
return true;
|
||||
|
@ -2465,6 +2494,8 @@ static bool ggml_thread_apply_priority(int32_t prio) {
|
|||
struct sched_param p;
|
||||
int32_t policy = SCHED_OTHER;
|
||||
switch (prio) {
|
||||
// TODO: there seems to be no way to set lower prio on Apple platforms
|
||||
case GGML_SCHED_PRIO_LOW: policy = SCHED_OTHER; p.sched_priority = 0; break;
|
||||
case GGML_SCHED_PRIO_NORMAL: policy = SCHED_OTHER; p.sched_priority = 0; break;
|
||||
case GGML_SCHED_PRIO_MEDIUM: policy = SCHED_FIFO; p.sched_priority = 40; break;
|
||||
case GGML_SCHED_PRIO_HIGH: policy = SCHED_FIFO; p.sched_priority = 80; break;
|
||||
|
@ -2521,6 +2552,7 @@ static bool ggml_thread_apply_priority(int32_t prio) {
|
|||
struct sched_param p;
|
||||
int32_t policy = SCHED_OTHER;
|
||||
switch (prio) {
|
||||
case GGML_SCHED_PRIO_LOW: policy = SCHED_BATCH; p.sched_priority = 0; break;
|
||||
case GGML_SCHED_PRIO_NORMAL: policy = SCHED_OTHER; p.sched_priority = 0; break;
|
||||
case GGML_SCHED_PRIO_MEDIUM: policy = SCHED_FIFO; p.sched_priority = 40; break;
|
||||
case GGML_SCHED_PRIO_HIGH: policy = SCHED_FIFO; p.sched_priority = 80; break;
|
||||
|
|
|
@ -635,6 +635,7 @@ struct ggml_cuda_device_info {
|
|||
int nsm; // number of streaming multiprocessors
|
||||
size_t smpb; // max. shared memory per block
|
||||
size_t smpbo; // max. shared memory per block (with opt-in)
|
||||
bool integrated; // Device is integrated as opposed to discrete
|
||||
bool vmm; // virtual memory support
|
||||
size_t vmm_granularity; // granularity of virtual memory
|
||||
size_t total_vram;
|
||||
|
|
|
@ -246,10 +246,10 @@ static ggml_cuda_device_info ggml_cuda_init() {
|
|||
|
||||
info.default_tensor_split[id] = total_vram;
|
||||
total_vram += prop.totalGlobalMem;
|
||||
|
||||
info.devices[id].nsm = prop.multiProcessorCount;
|
||||
info.devices[id].smpb = prop.sharedMemPerBlock;
|
||||
info.devices[id].warp_size = prop.warpSize;
|
||||
info.devices[id].integrated = prop.integrated;
|
||||
info.devices[id].nsm = prop.multiProcessorCount;
|
||||
info.devices[id].smpb = prop.sharedMemPerBlock;
|
||||
info.devices[id].warp_size = prop.warpSize;
|
||||
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
|
||||
info.devices[id].smpbo = prop.sharedMemPerBlock;
|
||||
|
||||
|
@ -1066,6 +1066,10 @@ static const char * ggml_backend_cuda_host_buffer_type_name(ggml_backend_buffer_
|
|||
GGML_UNUSED(buft);
|
||||
}
|
||||
|
||||
static bool ggml_backend_buft_is_cuda_host(ggml_backend_buffer_type_t buft) {
|
||||
return buft->iface.get_name == ggml_backend_cuda_host_buffer_type_name;
|
||||
}
|
||||
|
||||
static void ggml_backend_cuda_host_buffer_free_buffer(ggml_backend_buffer_t buffer) {
|
||||
CUDA_CHECK(cudaFreeHost(buffer->context));
|
||||
}
|
||||
|
@ -2646,6 +2650,8 @@ static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
|
|||
|
||||
static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
|
||||
bool & graph_evaluated_or_captured, bool & use_cuda_graph, bool & cuda_graph_update_required) {
|
||||
// flag used to determine whether it is an integrated_gpu
|
||||
const bool integrated = ggml_cuda_info().devices[cuda_ctx->device].integrated;
|
||||
|
||||
while (!graph_evaluated_or_captured) {
|
||||
// Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph.
|
||||
|
@ -2664,7 +2670,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
|
|||
if (node->src[j] != nullptr) {
|
||||
assert(node->src[j]->buffer);
|
||||
assert(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) ||
|
||||
ggml_backend_buft_is_cuda_split(node->src[j]->buffer->buft));
|
||||
ggml_backend_buft_is_cuda_split(node->src[j]->buffer->buft) || (integrated && ggml_backend_buft_is_cuda_host(node->src[j]->buffer->buft)));
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
@ -3271,7 +3277,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
|||
}
|
||||
|
||||
static bool ggml_backend_cuda_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
|
||||
return (ggml_backend_buft_is_cuda(buft) || ggml_backend_buft_is_cuda_split(buft)) && buft->device == dev;
|
||||
ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) dev->context;
|
||||
const bool integrated = ggml_cuda_info().devices[dev_ctx->device].integrated;
|
||||
return (((ggml_backend_buft_is_cuda(buft) || ggml_backend_buft_is_cuda_split(buft)) && buft->device == dev) || (integrated && ggml_backend_buft_is_cuda_host(buft)));
|
||||
}
|
||||
|
||||
static int64_t get_op_batch_size(const ggml_tensor * op) {
|
||||
|
|
|
@ -136,6 +136,7 @@ static std::vector<logit_bias> logit_biases;
|
|||
static bool add_bos_token = true; // if set to false, mmproj handling breaks. dont disable unless you know what you're doing
|
||||
static bool load_guidance = false; //whether to enable cfg for negative prompts
|
||||
static bool check_slowness = false; //will display a suggestion to use highpriority if slow
|
||||
static bool highpriority = false;
|
||||
|
||||
static int delayed_generated_tokens_limit = 0;
|
||||
std::deque<std::string> delayed_generated_tokens; //for use with antislop sampling
|
||||
|
@ -1972,6 +1973,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
|
|||
add_bos_token = !inputs.no_bos_token;
|
||||
load_guidance = inputs.load_guidance;
|
||||
check_slowness = inputs.check_slowness;
|
||||
highpriority = inputs.highpriority;
|
||||
|
||||
if(!add_bos_token)
|
||||
{
|
||||
|
@ -2356,6 +2358,11 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
|
|||
ggml_threadpool_params threadpool1_params, threadpool2_params;
|
||||
ggml_threadpool_params_init(&threadpool1_params,kcpp_data->n_threads);
|
||||
ggml_threadpool_params_init(&threadpool2_params,kcpp_data->n_blasthreads);
|
||||
if(inputs.highpriority)
|
||||
{
|
||||
threadpool1_params.prio = GGML_SCHED_PRIO_HIGH;
|
||||
threadpool2_params.prio = GGML_SCHED_PRIO_HIGH;
|
||||
}
|
||||
|
||||
printf("Threadpool set to %d threads and %d blasthreads...\n", kcpp_data->n_threads,kcpp_data->n_blasthreads);
|
||||
struct ggml_threadpool * threadpool1 = ggml_threadpool_new(&threadpool1_params);
|
||||
|
|
|
@ -262,9 +262,9 @@ extern "C" {
|
|||
llama_token * token;
|
||||
float * embd;
|
||||
llama_pos * pos;
|
||||
int32_t * n_seq_id;
|
||||
llama_seq_id ** seq_id;
|
||||
int8_t * logits; // TODO: rename this to "output"
|
||||
int32_t * n_seq_id; // TODO: remove, should belong to only 1 sequence
|
||||
llama_seq_id ** seq_id; // TODO: become llama_seq_id * seq_id;
|
||||
int8_t * logits; // TODO: rename this to "output"
|
||||
} llama_batch;
|
||||
|
||||
enum llama_model_kv_override_type {
|
||||
|
@ -369,6 +369,8 @@ extern "C" {
|
|||
bool no_perf; // measure performance timings
|
||||
bool op_offload; // offload host tensor operations to device
|
||||
bool swa_full; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
|
||||
// NOTE: setting to false when n_seq_max > 1 can cause bad performance in some cases
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/13845#issuecomment-2924800573
|
||||
};
|
||||
|
||||
// model quantization parameters
|
||||
|
@ -505,6 +507,7 @@ extern "C" {
|
|||
LLAMA_API int32_t llama_model_n_layer (const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_model_n_head (const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_model_n_head_kv (const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_model_n_swa (const struct llama_model * model);
|
||||
|
||||
// Get the model's RoPE frequency scaling factor
|
||||
LLAMA_API float llama_model_rope_freq_scale_train(const struct llama_model * model);
|
||||
|
@ -655,7 +658,6 @@ extern "C" {
|
|||
// Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
|
||||
// If the KV cache is RoPEd, the KV data is updated accordingly:
|
||||
// - lazily on next llama_decode()
|
||||
// - explicitly with llama_kv_self_update()
|
||||
// p0 < 0 : [0, p1]
|
||||
// p1 < 0 : [p0, inf)
|
||||
LLAMA_API void llama_kv_self_seq_add(
|
||||
|
@ -668,7 +670,6 @@ extern "C" {
|
|||
// Integer division of the positions by factor of `d > 1`
|
||||
// If the KV cache is RoPEd, the KV data is updated accordingly:
|
||||
// - lazily on next llama_decode()
|
||||
// - explicitly with llama_kv_self_update()
|
||||
// p0 < 0 : [0, p1]
|
||||
// p1 < 0 : [p0, inf)
|
||||
LLAMA_API void llama_kv_self_seq_div(
|
||||
|
@ -680,12 +681,14 @@ extern "C" {
|
|||
|
||||
// Returns the smallest position present in the KV cache for the specified sequence
|
||||
// This is typically non-zero only for SWA caches
|
||||
// Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache
|
||||
// Return -1 if the sequence is empty
|
||||
LLAMA_API llama_pos llama_kv_self_seq_pos_min(
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id);
|
||||
|
||||
// Returns the largest position present in the KV cache for the specified sequence
|
||||
// Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache
|
||||
// Return -1 if the sequence is empty
|
||||
LLAMA_API llama_pos llama_kv_self_seq_pos_max(
|
||||
struct llama_context * ctx,
|
||||
|
@ -694,14 +697,15 @@ extern "C" {
|
|||
// Defragment the KV cache
|
||||
// This will be applied:
|
||||
// - lazily on next llama_decode()
|
||||
// - explicitly with llama_kv_self_update()
|
||||
LLAMA_API void llama_kv_self_defrag(struct llama_context * ctx);
|
||||
LLAMA_API DEPRECATED(void llama_kv_self_defrag(struct llama_context * ctx),
|
||||
"simply remove this call, the context will automatically decide when to do a defragmentation based on 'defrag_thold'");
|
||||
|
||||
// Check if the context supports KV cache shifting
|
||||
LLAMA_API bool llama_kv_self_can_shift(const struct llama_context * ctx);
|
||||
|
||||
// Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
|
||||
LLAMA_API void llama_kv_self_update(struct llama_context * ctx);
|
||||
LLAMA_API DEPRECATED(void llama_kv_self_update(struct llama_context * ctx),
|
||||
"simply remove this call, updates are applied lazily on the next llama_decode()");
|
||||
|
||||
//
|
||||
// State / sessions
|
||||
|
|
|
@ -1013,7 +1013,7 @@ Current version indicated by LITEVER below.
|
|||
}
|
||||
.scenariogrid
|
||||
{
|
||||
height: 260px;
|
||||
height: 240px;
|
||||
overflow-y: auto;
|
||||
margin-top: 4px;
|
||||
padding: 8px;
|
||||
|
@ -9081,7 +9081,7 @@ Current version indicated by LITEVER below.
|
|||
}
|
||||
|
||||
document.getElementById("scenariogrid").innerHTML = scenarios;
|
||||
document.getElementById("scenariodesc").innerText = "No Scenario Selected";
|
||||
document.getElementById("scenariodesc").innerText = "No Scenario Selected. Scroll for more options.";
|
||||
togglescenarioautopick();
|
||||
}
|
||||
|
||||
|
|
|
@ -191,6 +191,7 @@ class load_model_inputs(ctypes.Structure):
|
|||
("quant_k", ctypes.c_int),
|
||||
("quant_v", ctypes.c_int),
|
||||
("check_slowness", ctypes.c_bool),
|
||||
("highpriority", ctypes.c_bool),
|
||||
("swa_support", ctypes.c_bool),
|
||||
("lora_multiplier", ctypes.c_float),
|
||||
("quiet", ctypes.c_bool),
|
||||
|
@ -1247,6 +1248,7 @@ def load_model(model_filename):
|
|||
inputs.override_kv = args.overridekv.encode("UTF-8") if args.overridekv else "".encode("UTF-8")
|
||||
inputs.override_tensors = args.overridetensors.encode("UTF-8") if args.overridetensors else "".encode("UTF-8")
|
||||
inputs.check_slowness = (not args.highpriority and os.name == 'nt' and 'Intel' in platform.processor())
|
||||
inputs.highpriority = args.highpriority
|
||||
inputs.swa_support = args.useswa
|
||||
inputs = set_backend_props(inputs)
|
||||
ret = handle.load_model(inputs)
|
||||
|
|
|
@ -15,24 +15,31 @@ llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) {
|
|||
break;
|
||||
}
|
||||
}
|
||||
ubatch_token.resize(!has_embd ? n_ubatch : 0);
|
||||
ubatch_embd.resize(has_embd ? n_embd * n_ubatch : 0);
|
||||
ubatch_pos.resize(n_ubatch);
|
||||
ubatch_n_seq_id.resize(n_ubatch);
|
||||
ubatch_seq_id.resize(n_ubatch);
|
||||
ubatch_output.resize(n_ubatch);
|
||||
|
||||
udatas.push_back({});
|
||||
|
||||
auto & udata = udatas.back();
|
||||
|
||||
udata.token.resize(!has_embd ? n_ubatch : 0);
|
||||
udata.embd.resize(has_embd ? n_embd * n_ubatch : 0);
|
||||
udata.pos.resize(n_ubatch);
|
||||
udata.n_seq_id.resize(n_ubatch);
|
||||
udata.seq_id.resize(n_ubatch);
|
||||
udata.output.resize(n_ubatch);
|
||||
|
||||
llama_ubatch ubatch = {
|
||||
/*equal_seqs =*/ true,
|
||||
/*n_tokens =*/ 0,
|
||||
/*n_seq_tokens =*/ 0,
|
||||
/*n_seqs =*/ 0,
|
||||
/*token =*/ !has_embd ? ubatch_token.data() : nullptr,
|
||||
/*embd =*/ has_embd ? ubatch_embd.data() : nullptr,
|
||||
/*pos =*/ ubatch_pos.data(),
|
||||
/*n_seq_id =*/ ubatch_n_seq_id.data(),
|
||||
/*seq_id =*/ ubatch_seq_id.data(),
|
||||
/*output =*/ ubatch_output.data(),
|
||||
/*token =*/ !has_embd ? udata.token.data() : nullptr,
|
||||
/*embd =*/ has_embd ? udata.embd.data() : nullptr,
|
||||
/*pos =*/ udata.pos.data(),
|
||||
/*n_seq_id =*/ udata.n_seq_id.data(),
|
||||
/*seq_id =*/ udata.seq_id.data(),
|
||||
/*output =*/ udata.output.data(),
|
||||
};
|
||||
|
||||
return ubatch;
|
||||
}
|
||||
|
||||
|
|
|
@ -11,15 +11,15 @@ struct llama_ubatch {
|
|||
bool equal_seqs;
|
||||
// TODO: whole_seqs for embeddings?
|
||||
|
||||
uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs)
|
||||
uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs)
|
||||
uint32_t n_seq_tokens; // tokens per sequence
|
||||
uint32_t n_seqs;
|
||||
|
||||
llama_token * token; // [n_tokens]
|
||||
float * embd; // [n_embd, n_tokens]
|
||||
llama_pos * pos; // [n_tokens]
|
||||
int32_t * n_seq_id; // [n_seqs]
|
||||
llama_seq_id ** seq_id; // [n_seqs]
|
||||
int32_t * n_seq_id; // [n_seqs] // TODO: remove, should belong to only 1 sequence
|
||||
llama_seq_id ** seq_id; // [n_seqs] // TODO: become llama_seq_id * seq_id;
|
||||
int8_t * output; // [n_tokens]
|
||||
};
|
||||
|
||||
|
@ -49,13 +49,18 @@ struct llama_sbatch {
|
|||
|
||||
const llama_batch * batch = nullptr;
|
||||
|
||||
// buffers for the ubatch
|
||||
std::vector<llama_token> ubatch_token;
|
||||
std::vector<float> ubatch_embd;
|
||||
std::vector<llama_pos> ubatch_pos;
|
||||
std::vector<int32_t> ubatch_n_seq_id;
|
||||
std::vector<llama_seq_id *> ubatch_seq_id;
|
||||
std::vector<int8_t> ubatch_output;
|
||||
// buffers for the ubatches
|
||||
// TODO: very hacky, this needs a complete rework
|
||||
struct ubatch_data {
|
||||
std::vector<llama_token> token;
|
||||
std::vector<float> embd;
|
||||
std::vector<llama_pos> pos;
|
||||
std::vector<int32_t> n_seq_id;
|
||||
std::vector<llama_seq_id *> seq_id;
|
||||
std::vector<int8_t> output;
|
||||
};
|
||||
|
||||
std::vector<ubatch_data> udatas;
|
||||
|
||||
llama_ubatch reserve_ubatch(size_t n_ubatch, bool has_embd = false);
|
||||
|
||||
|
|
|
@ -6,9 +6,10 @@
|
|||
#include "llama-model.h"
|
||||
#include "llama-kv-cache.h"
|
||||
|
||||
#include <cstring>
|
||||
#include <stdexcept>
|
||||
#include <cinttypes>
|
||||
#include <cstring>
|
||||
#include <limits>
|
||||
#include <stdexcept>
|
||||
|
||||
//
|
||||
// llama_context
|
||||
|
@ -122,6 +123,11 @@ llama_context::llama_context(
|
|||
__func__, n_ctx_per_seq, hparams.n_ctx_train);
|
||||
}
|
||||
|
||||
if (!params.swa_full && cparams.n_seq_max > 1) {
|
||||
LLAMA_LOG_WARN("%s: requested n_seq_max (%u) > 1, but swa_full is not enabled -- performance may be degraded: %s\n",
|
||||
__func__, cparams.n_seq_max, "https://github.com/ggml-org/llama.cpp/pull/13845#issuecomment-2924800573");
|
||||
}
|
||||
|
||||
if (!hparams.vocab_only) {
|
||||
// GPU backends
|
||||
for (auto * dev : model.devices) {
|
||||
|
@ -259,15 +265,9 @@ llama_context::llama_context(
|
|||
|
||||
// reserve worst-case graph
|
||||
if (!hparams.vocab_only && memory) {
|
||||
const uint32_t n_seqs = 1; // TODO: worst-case number of sequences
|
||||
const uint32_t n_seqs = cparams.n_seq_max;
|
||||
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
|
||||
|
||||
llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
|
||||
|
||||
// restore later
|
||||
// TODO: something cleaner
|
||||
const auto n_outputs_save = n_outputs;
|
||||
|
||||
LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
|
||||
|
||||
int n_splits_pp = -1;
|
||||
|
@ -279,23 +279,17 @@ llama_context::llama_context(
|
|||
// simulate full KV cache
|
||||
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
|
||||
|
||||
kv_self->set_full();
|
||||
const auto kv_state = kv_self->init_full();
|
||||
if (!kv_state) {
|
||||
throw std::runtime_error("failed to initialize KV cache");
|
||||
}
|
||||
|
||||
cross.v_embd.clear();
|
||||
|
||||
// reserve pp graph first so that buffers are only allocated once
|
||||
{
|
||||
llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
|
||||
|
||||
// max number of outputs
|
||||
n_outputs = ubatch_pp.n_tokens;
|
||||
|
||||
LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_pp.n_tokens, ubatch_pp.n_seqs);
|
||||
|
||||
auto * gf = graph_init();
|
||||
graph_build(ctx_compute.get(), gf, ubatch_pp, LLM_GRAPH_TYPE_DEFAULT);
|
||||
|
||||
if (!ggml_backend_sched_reserve(sched.get(), gf)) {
|
||||
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
|
||||
if (!gf) {
|
||||
throw std::runtime_error("failed to allocate compute pp buffers");
|
||||
}
|
||||
|
||||
|
@ -305,16 +299,8 @@ llama_context::llama_context(
|
|||
|
||||
// reserve with tg graph to get the number of splits and nodes
|
||||
{
|
||||
llama_ubatch ubatch_tg = { true, 1, 1, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
|
||||
|
||||
n_outputs = ubatch_tg.n_tokens;
|
||||
|
||||
LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_tg.n_tokens, ubatch_tg.n_seqs);
|
||||
|
||||
auto * gf = graph_init();
|
||||
graph_build(ctx_compute.get(), gf, ubatch_tg, LLM_GRAPH_TYPE_DEFAULT);
|
||||
|
||||
if (!ggml_backend_sched_reserve(sched.get(), gf)) {
|
||||
auto * gf = graph_reserve(1, 1, 1, kv_state.get());
|
||||
if (!gf) {
|
||||
throw std::runtime_error("failed to allocate compute tg buffers");
|
||||
}
|
||||
|
||||
|
@ -324,22 +310,12 @@ llama_context::llama_context(
|
|||
|
||||
// reserve again with pp graph to avoid ggml-alloc reallocations during inference
|
||||
{
|
||||
llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
|
||||
|
||||
n_outputs = ubatch_pp.n_tokens;
|
||||
|
||||
LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_pp.n_tokens, ubatch_pp.n_seqs);
|
||||
|
||||
auto * gf = graph_init();
|
||||
graph_build(ctx_compute.get(), gf, ubatch_pp, LLM_GRAPH_TYPE_DEFAULT);
|
||||
|
||||
if (!ggml_backend_sched_reserve(sched.get(), gf)) {
|
||||
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
|
||||
if (!gf) {
|
||||
throw std::runtime_error("failed to allocate compute pp buffers");
|
||||
}
|
||||
}
|
||||
|
||||
n_outputs = n_outputs_save;
|
||||
|
||||
for (size_t i = 0; i < backend_ptrs.size(); ++i) {
|
||||
ggml_backend_t backend = backend_ptrs[i];
|
||||
ggml_backend_buffer_type_t buft = backend_buft[i];
|
||||
|
@ -453,36 +429,33 @@ const llama_kv_cache * llama_context::get_kv_self() const {
|
|||
return kv_self;
|
||||
}
|
||||
|
||||
void llama_context::kv_self_update() {
|
||||
bool need_reserve = false;
|
||||
bool llama_context::kv_self_update() {
|
||||
if (!memory) {
|
||||
return false;
|
||||
}
|
||||
|
||||
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
|
||||
|
||||
need_reserve = kv_self->update(*this);
|
||||
|
||||
// reserve a worst case graph if needed
|
||||
if (need_reserve) {
|
||||
// LLAMA_LOG_DEBUG("%s: reserving a worst case graph\n", __func__);
|
||||
|
||||
// build worst-case graph
|
||||
uint32_t n_seqs = 1; // TODO: worst-case number of sequences
|
||||
uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
|
||||
|
||||
// simulate full KV cache
|
||||
kv_self->set_full();
|
||||
|
||||
llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
|
||||
llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
|
||||
|
||||
auto * gf = graph_init();
|
||||
graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT);
|
||||
|
||||
// initialize scheduler with the worst-case graph
|
||||
ggml_backend_sched_reset(sched.get());
|
||||
if (!ggml_backend_sched_reserve(sched.get(), gf)) {
|
||||
LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
|
||||
}
|
||||
if (!kv_self->update(*this)) {
|
||||
// no updates have been performed
|
||||
return false;
|
||||
}
|
||||
|
||||
// if the KV cache did any computation, we have to reserve a new worst-case graph
|
||||
const auto kv_state = kv_self->init_full();
|
||||
if (!kv_state) {
|
||||
throw std::runtime_error("failed to initialize KV cache");
|
||||
}
|
||||
|
||||
const uint32_t n_seqs = cparams.n_seq_max;
|
||||
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
|
||||
|
||||
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
|
||||
if (!gf) {
|
||||
LLAMA_LOG_ERROR("%s: failed to reserve graph after the KV cache update\n", __func__);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
enum llama_pooling_type llama_context::pooling_type() const {
|
||||
|
@ -676,6 +649,49 @@ bool llama_context::apply_adapter_cvec(
|
|||
return cvec.apply(model, data, len, n_embd, il_start, il_end);
|
||||
}
|
||||
|
||||
llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_state_i * mstate, ggml_status & ret) {
|
||||
if (mstate && !mstate->apply()) {
|
||||
LLAMA_LOG_ERROR("%s: failed to apply memory state\n", __func__);
|
||||
ret = GGML_STATUS_FAILED;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto * gf = graph_init();
|
||||
if (!gf) {
|
||||
LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__);
|
||||
ret = GGML_STATUS_FAILED;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype, mstate);
|
||||
if (!res) {
|
||||
LLAMA_LOG_ERROR("%s: failed to build graph\n", __func__);
|
||||
ret = GGML_STATUS_FAILED;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
|
||||
|
||||
if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) {
|
||||
LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__);
|
||||
ret = GGML_STATUS_ALLOC_FAILED;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
res->set_inputs(&ubatch);
|
||||
|
||||
const auto status = graph_compute(gf, ubatch.n_tokens > 1);
|
||||
if (status != GGML_STATUS_SUCCESS) {
|
||||
LLAMA_LOG_ERROR("%s: failed to compute graph, compute status: %d\n", __func__, status);
|
||||
ret = status;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
ret = GGML_STATUS_SUCCESS;
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
int llama_context::encode(llama_batch & inp_batch) {
|
||||
if (inp_batch.n_tokens == 0) {
|
||||
LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
|
||||
|
@ -737,8 +753,6 @@ int llama_context::encode(llama_batch & inp_batch) {
|
|||
|
||||
n_outputs = n_tokens;
|
||||
|
||||
//batch_manager->prepare(ubatch);
|
||||
|
||||
ggml_backend_sched_reset(sched.get());
|
||||
ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
|
||||
|
||||
|
@ -749,26 +763,18 @@ int llama_context::encode(llama_batch & inp_batch) {
|
|||
// ref: https://github.com/ggml-org/llama.cpp/pull/12181#issuecomment-2730451223
|
||||
cparams.causal_attn = false;
|
||||
|
||||
auto * gf = graph_init();
|
||||
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_ENCODER);
|
||||
|
||||
ggml_backend_sched_alloc_graph(sched.get(), gf);
|
||||
|
||||
res->set_inputs(&ubatch);
|
||||
ggml_status status;
|
||||
const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status);
|
||||
|
||||
cparams.causal_attn = causal_attn_org;
|
||||
|
||||
const auto compute_status = graph_compute(gf, n_tokens > 1);
|
||||
switch (compute_status) {
|
||||
case GGML_STATUS_SUCCESS:
|
||||
break;
|
||||
case GGML_STATUS_ABORTED:
|
||||
return 2;
|
||||
case GGML_STATUS_ALLOC_FAILED:
|
||||
return -2;
|
||||
case GGML_STATUS_FAILED:
|
||||
default:
|
||||
return -3;
|
||||
if (!res) {
|
||||
switch (status) {
|
||||
case GGML_STATUS_ABORTED: return 2;
|
||||
case GGML_STATUS_ALLOC_FAILED: return -2;
|
||||
case GGML_STATUS_FAILED: return -3;
|
||||
case GGML_STATUS_SUCCESS: GGML_ABORT("should not happen");
|
||||
}
|
||||
}
|
||||
|
||||
auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd();
|
||||
|
@ -889,8 +895,6 @@ int llama_context::decode(llama_batch & inp_batch) {
|
|||
const int64_t n_tokens_all = batch.n_tokens;
|
||||
const int64_t n_embd = hparams.n_embd;
|
||||
|
||||
llama_kv_cache_guard kv_guard(kv_self);
|
||||
|
||||
GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
|
||||
|
||||
// TODO: move the validation to the llama_batch_allocr
|
||||
|
@ -936,7 +940,48 @@ int llama_context::decode(llama_batch & inp_batch) {
|
|||
n_outputs_all = 1;
|
||||
}
|
||||
|
||||
llama_sbatch sbatch = kv_self->sbatch_init(batch, /* logits_all */ n_outputs_all == n_tokens_all);
|
||||
// handle any pending defrags/shifts
|
||||
kv_self_update();
|
||||
|
||||
llama_memory_state_ptr kv_state;
|
||||
|
||||
bool did_defrag = false;
|
||||
|
||||
while (true) {
|
||||
kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
|
||||
if (!kv_state) {
|
||||
return -2;
|
||||
}
|
||||
|
||||
switch (kv_state->get_status()) {
|
||||
case LLAMA_MEMORY_STATUS_SUCCESS:
|
||||
{
|
||||
} break;
|
||||
case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
|
||||
{
|
||||
if (!did_defrag) {
|
||||
did_defrag = true;
|
||||
|
||||
kv_self->defrag_sched(-1.0f);
|
||||
if (kv_self_update()) {
|
||||
LLAMA_LOG_DEBUG("%s: failed to init batch of size %d, retrying after defrag\n", __func__, batch.n_tokens);
|
||||
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
LLAMA_LOG_WARN("%s: failed to find KV cache slot for batch of size %d\n", __func__, batch.n_tokens);
|
||||
|
||||
return 1;
|
||||
}
|
||||
case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
|
||||
{
|
||||
return -2;
|
||||
}
|
||||
}
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
// reserve output buffer
|
||||
if (output_reserve(n_outputs_all) < n_outputs_all) {
|
||||
|
@ -944,13 +989,10 @@ int llama_context::decode(llama_batch & inp_batch) {
|
|||
return -2;
|
||||
};
|
||||
|
||||
// handle any pending defrags/shifts
|
||||
kv_self_update();
|
||||
|
||||
int64_t n_outputs_prev = 0;
|
||||
|
||||
while (sbatch.n_tokens > 0) {
|
||||
llama_ubatch ubatch = kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled);
|
||||
do {
|
||||
const auto & ubatch = kv_state->get_ubatch();
|
||||
|
||||
// count the outputs in this u_batch
|
||||
{
|
||||
|
@ -969,33 +1011,37 @@ int llama_context::decode(llama_batch & inp_batch) {
|
|||
n_outputs = n_outputs_new;
|
||||
}
|
||||
|
||||
// find KV slot
|
||||
if (!kv_self->find_slot(ubatch)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
ggml_backend_sched_reset(sched.get());
|
||||
ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
|
||||
|
||||
auto * gf = graph_init();
|
||||
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DECODER);
|
||||
ggml_status status;
|
||||
const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, kv_state.get(), status);
|
||||
|
||||
// LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
|
||||
if (!res) {
|
||||
// the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
|
||||
llama_pos pos_min[LLAMA_MAX_PARALLEL_SEQUENCES] = { std::numeric_limits<llama_pos>::max() };
|
||||
|
||||
ggml_backend_sched_alloc_graph(sched.get(), gf);
|
||||
for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
|
||||
const auto & seq_id = ubatch.seq_id[i][0];
|
||||
|
||||
res->set_inputs(&ubatch);
|
||||
pos_min[seq_id] = std::min(pos_min[seq_id], ubatch.pos[i]);
|
||||
}
|
||||
|
||||
const auto compute_status = graph_compute(gf, ubatch.n_tokens > 1);
|
||||
if (compute_status != GGML_STATUS_SUCCESS) {
|
||||
switch (compute_status) {
|
||||
case GGML_STATUS_ABORTED:
|
||||
return 2;
|
||||
case GGML_STATUS_ALLOC_FAILED:
|
||||
return -2;
|
||||
case GGML_STATUS_FAILED:
|
||||
default:
|
||||
return -3;
|
||||
for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
|
||||
if (pos_min[s] == std::numeric_limits<llama_pos>::max()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
LLAMA_LOG_WARN("%s: removing KV cache entries for seq_id = %d, pos = [%d, +inf)\n", __func__, s, pos_min[s]);
|
||||
|
||||
llama_kv_self_seq_rm(this, s, pos_min[s], -1);
|
||||
}
|
||||
|
||||
switch (status) {
|
||||
case GGML_STATUS_ABORTED: return 2;
|
||||
case GGML_STATUS_ALLOC_FAILED: return -2;
|
||||
case GGML_STATUS_FAILED: return -3;
|
||||
case GGML_STATUS_SUCCESS: GGML_ABORT("should not happen");
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1082,10 +1128,7 @@ int llama_context::decode(llama_batch & inp_batch) {
|
|||
}
|
||||
|
||||
n_outputs_prev += n_outputs;
|
||||
}
|
||||
|
||||
// finalize the batch processing
|
||||
kv_guard.commit();
|
||||
} while (kv_state->next());
|
||||
|
||||
// set to total number of outputs in the batch, for use in llama_get_logits_ith
|
||||
n_outputs = n_outputs_all;
|
||||
|
@ -1094,7 +1137,7 @@ int llama_context::decode(llama_batch & inp_batch) {
|
|||
{
|
||||
bool sorted_output = true;
|
||||
|
||||
auto & out_ids = sbatch.out_ids;
|
||||
auto & out_ids = kv_state->out_ids();
|
||||
|
||||
GGML_ASSERT(out_ids.size() == (size_t) n_outputs_all);
|
||||
|
||||
|
@ -1254,11 +1297,52 @@ ggml_cgraph * llama_context::graph_init() {
|
|||
return ggml_new_graph_custom(ctx_compute.get(), graph_max_nodes(), false);
|
||||
}
|
||||
|
||||
ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_state_i * mstate) {
|
||||
//LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs);
|
||||
|
||||
if (n_tokens % n_seqs != 0) {
|
||||
n_tokens = (n_tokens / n_seqs) * n_seqs;
|
||||
n_outputs = std::min(n_outputs, n_tokens);
|
||||
|
||||
//LLAMA_LOG_DEBUG("%s: making n_tokens a multiple of n_seqs - n_tokens = %u, n_seqs = %u, n_outputs = %u\n", __func__, n_tokens, n_seqs, n_outputs);
|
||||
}
|
||||
|
||||
// store the n_outputs as it is, and restore it afterwards
|
||||
// TODO: not sure if needed, might simplify in the future by removing this
|
||||
const auto save_n_outputs = this->n_outputs;
|
||||
|
||||
this->n_outputs = n_outputs;
|
||||
|
||||
llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
|
||||
llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
|
||||
|
||||
auto * gf = graph_init();
|
||||
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mstate);
|
||||
|
||||
this->n_outputs = save_n_outputs;
|
||||
|
||||
if (!res) {
|
||||
LLAMA_LOG_ERROR("%s: failed to build worst-case graph\n", __func__);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
ggml_backend_sched_reset(sched.get());
|
||||
|
||||
// initialize scheduler with the specified graph
|
||||
if (!ggml_backend_sched_reserve(sched.get(), gf)) {
|
||||
LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return gf;
|
||||
}
|
||||
|
||||
llm_graph_result_ptr llama_context::graph_build(
|
||||
ggml_context * ctx,
|
||||
ggml_cgraph * gf,
|
||||
const llama_ubatch & ubatch,
|
||||
llm_graph_type gtype) {
|
||||
ggml_context * ctx,
|
||||
ggml_cgraph * gf,
|
||||
const llama_ubatch & ubatch,
|
||||
llm_graph_type gtype,
|
||||
const llama_memory_state_i * mstate) {
|
||||
return model.build_graph(
|
||||
{
|
||||
/*.ctx =*/ ctx,
|
||||
|
@ -1270,7 +1354,7 @@ llm_graph_result_ptr llama_context::graph_build(
|
|||
/*.backend_cpu =*/ backend_cpu,
|
||||
/*.cvec =*/ &cvec,
|
||||
/*.loras =*/ &loras,
|
||||
/*.memory =*/ memory.get(),
|
||||
/*.mstate =*/ mstate,
|
||||
/*.cross =*/ &cross,
|
||||
/*.n_outputs =*/ n_outputs,
|
||||
/*.cb =*/ graph_get_cb(),
|
||||
|
@ -1951,7 +2035,6 @@ void llama_context::opt_epoch_iter(
|
|||
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
|
||||
|
||||
kv_self->clear();
|
||||
llama_kv_cache_guard kv_guard(kv_self);
|
||||
|
||||
for (uint32_t pos_ctx = 0; pos_ctx < n_ctx; pos_ctx += n_batch) {
|
||||
batch.n_tokens = n_batch;
|
||||
|
@ -1974,7 +2057,11 @@ void llama_context::opt_epoch_iter(
|
|||
|
||||
int64_t n_outputs_all = n_tokens_all;
|
||||
|
||||
llama_sbatch sbatch = kv_self->sbatch_init(batch, /*logits_all =*/ true);
|
||||
auto kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ true);
|
||||
if (!kv_state || kv_state->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
|
||||
LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
|
||||
break;
|
||||
}
|
||||
|
||||
// reserve output buffer
|
||||
if (output_reserve(n_outputs_all) < n_outputs_all) {
|
||||
|
@ -1982,20 +2069,19 @@ void llama_context::opt_epoch_iter(
|
|||
GGML_ABORT("TODO: handle this error");
|
||||
};
|
||||
|
||||
for (uint32_t pos_batch = 0; pos_batch < n_batch; pos_batch += n_ubatch) {
|
||||
llama_ubatch ubatch = kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled);
|
||||
uint32_t pos_batch = 0;
|
||||
do {
|
||||
const auto & ubatch = kv_state->get_ubatch();
|
||||
|
||||
n_outputs = ubatch.n_tokens;
|
||||
|
||||
// TODO: not sure if this is needed
|
||||
if (!kv_self->find_slot(ubatch)) {
|
||||
LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
|
||||
|
||||
GGML_ABORT("TODO: handle this error");
|
||||
if (!kv_state->apply()) {
|
||||
LLAMA_LOG_ERROR("%s: failed to update the memory state\n", __func__);
|
||||
break;
|
||||
}
|
||||
|
||||
auto * gf = graph_init();
|
||||
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT);
|
||||
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, kv_state.get());
|
||||
|
||||
struct ggml_context * ctx_compute_opt;
|
||||
{
|
||||
|
@ -2010,6 +2096,7 @@ void llama_context::opt_epoch_iter(
|
|||
}
|
||||
ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_tokens(), res->get_logits());
|
||||
ggml_opt_alloc(opt_ctx, train);
|
||||
|
||||
res->set_inputs(&ubatch);
|
||||
{
|
||||
struct ggml_tensor * labels = ggml_opt_labels(opt_ctx);
|
||||
|
@ -2027,10 +2114,10 @@ void llama_context::opt_epoch_iter(
|
|||
callback(train, opt_ctx, dataset, result, idata_in_loop + (pos_ctx + pos_batch)/n_ubatch + 1, ndata_in_loop, t_loop_start);
|
||||
}
|
||||
ggml_free(ctx_compute_opt);
|
||||
}
|
||||
}
|
||||
|
||||
kv_guard.commit();
|
||||
pos_batch += ubatch.n_tokens;
|
||||
} while (kv_state->next());
|
||||
}
|
||||
}
|
||||
|
||||
void llama_context::opt_epoch(
|
||||
|
@ -2194,6 +2281,7 @@ llama_kv_cache * llama_get_kv_self(llama_context * ctx) {
|
|||
return ctx->get_kv_self();
|
||||
}
|
||||
|
||||
// deprecated
|
||||
void llama_kv_self_update(llama_context * ctx) {
|
||||
ctx->kv_self_update();
|
||||
}
|
||||
|
@ -2448,6 +2536,7 @@ llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
|
|||
return kv->seq_pos_max(seq_id);
|
||||
}
|
||||
|
||||
// deprecated
|
||||
void llama_kv_self_defrag(llama_context * ctx) {
|
||||
auto * kv = ctx->get_kv_self();
|
||||
if (!kv) {
|
||||
|
@ -2589,22 +2678,8 @@ int32_t llama_encode(
|
|||
int32_t llama_decode(
|
||||
llama_context * ctx,
|
||||
llama_batch batch) {
|
||||
int ret = ctx->decode(batch);
|
||||
|
||||
// defrag and try again
|
||||
// TODO: distinguish return code when we are sure that even after defrag there is no space available
|
||||
if (ret == 1) {
|
||||
llama_kv_self_defrag(ctx);
|
||||
ret = ctx->decode(batch);
|
||||
|
||||
if (ret == 1) {
|
||||
LLAMA_LOG_WARN("%s: failed to find KV cache slot for batch of size %d\n", __func__, batch.n_tokens);
|
||||
|
||||
return ret;
|
||||
}
|
||||
}
|
||||
|
||||
if (ret != 0) {
|
||||
const int ret = ctx->decode(batch);
|
||||
if (ret != 0 && ret != 1) {
|
||||
LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
|
||||
}
|
||||
|
||||
|
|
|
@ -18,6 +18,9 @@ struct llama_kv_cache;
|
|||
class llama_io_read_i;
|
||||
class llama_io_write_i;
|
||||
|
||||
class llama_memory_i;
|
||||
class llama_memory_state_i;
|
||||
|
||||
struct llama_context {
|
||||
// init scheduler and compute buffers, reserve worst-case graphs
|
||||
llama_context(
|
||||
|
@ -47,7 +50,9 @@ struct llama_context {
|
|||
llama_kv_cache * get_kv_self();
|
||||
const llama_kv_cache * get_kv_self() const;
|
||||
|
||||
void kv_self_update();
|
||||
// return true of the KV cache was updated
|
||||
// TODO: remove
|
||||
bool kv_self_update();
|
||||
|
||||
enum llama_pooling_type pooling_type() const;
|
||||
|
||||
|
@ -88,6 +93,16 @@ struct llama_context {
|
|||
int32_t il_start,
|
||||
int32_t il_end);
|
||||
|
||||
// process a single ubatch with a specific graph type
|
||||
// if memory_state is provided, it will be applied first to the context's memory
|
||||
// ret contains the status of the graph computation
|
||||
// returns nullptr only if ret != GGML_STATUS_SUCCESS
|
||||
llm_graph_result_ptr process_ubatch(
|
||||
const llama_ubatch & ubatch,
|
||||
llm_graph_type gtype,
|
||||
llama_memory_state_i * mstate,
|
||||
ggml_status & ret);
|
||||
|
||||
int encode(llama_batch & inp_batch);
|
||||
int decode(llama_batch & inp_batch);
|
||||
|
||||
|
@ -180,16 +195,18 @@ public:
|
|||
ggml_cgraph * graph_init();
|
||||
|
||||
// returns the result of ggml_backend_sched_graph_compute_async execution
|
||||
ggml_status graph_compute(
|
||||
ggml_cgraph * gf,
|
||||
bool batched);
|
||||
ggml_status graph_compute(ggml_cgraph * gf, bool batched);
|
||||
|
||||
// reserve a graph with a dummy ubatch of the specified size
|
||||
ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_state_i * mstate);
|
||||
|
||||
private:
|
||||
llm_graph_result_ptr graph_build(
|
||||
ggml_context * ctx,
|
||||
ggml_cgraph * gf,
|
||||
const llama_ubatch & ubatch,
|
||||
llm_graph_type gtype);
|
||||
ggml_context * ctx,
|
||||
ggml_cgraph * gf,
|
||||
const llama_ubatch & ubatch,
|
||||
llm_graph_type gtype,
|
||||
const llama_memory_state_i * mstate);
|
||||
|
||||
llm_graph_cb graph_get_cb() const;
|
||||
|
||||
|
|
|
@ -83,7 +83,7 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
|
|||
|
||||
void llm_graph_input_pos_bucket_kv::set_input(const llama_ubatch * ubatch) {
|
||||
if (pos_bucket) {
|
||||
kv_self->set_input_pos_bucket(pos_bucket, ubatch);
|
||||
kv_state->set_input_pos_bucket(pos_bucket, ubatch);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -234,7 +234,7 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
|
|||
void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
|
||||
GGML_UNUSED(ubatch);
|
||||
|
||||
const int64_t n_kv = kv_self->n;
|
||||
const int64_t n_kv = kv_state->get_n_kv();
|
||||
|
||||
if (s_copy) {
|
||||
GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
|
||||
|
@ -242,7 +242,7 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
|
|||
|
||||
// assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
|
||||
for (uint32_t i = 0; i < n_kv; ++i) {
|
||||
data[i] = kv_self->s_copy(i);
|
||||
data[i] = kv_state->s_copy(i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -250,7 +250,7 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
|
|||
void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) {
|
||||
GGML_UNUSED(ubatch);
|
||||
|
||||
const int64_t n_kv = kv_self->n;
|
||||
const int64_t n_kv = kv_state->get_n_kv();
|
||||
|
||||
if (s_mask) {
|
||||
GGML_ASSERT(ggml_backend_buffer_is_host(s_mask->buffer));
|
||||
|
@ -258,7 +258,7 @@ void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) {
|
|||
|
||||
// clear unused states
|
||||
for (int i = 0; i < n_kv; ++i) {
|
||||
data[i] = kv_self->s_mask(i);
|
||||
data[i] = kv_state->s_mask(i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -362,17 +362,17 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
|
|||
|
||||
void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
|
||||
if (self_kq_mask) {
|
||||
kv_self->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
|
||||
kv_state->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
|
||||
}
|
||||
}
|
||||
|
||||
void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
|
||||
if (self_kq_mask) {
|
||||
kv_self->get_kv_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
|
||||
kv_state->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
|
||||
}
|
||||
|
||||
if (self_kq_mask_swa) {
|
||||
kv_self->get_kv_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
|
||||
kv_state->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -448,7 +448,7 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
|
|||
backend_cpu (params.backend_cpu),
|
||||
cvec (params.cvec),
|
||||
loras (params.loras),
|
||||
memory (params.memory),
|
||||
mstate (params.mstate),
|
||||
cross (params.cross),
|
||||
cb_func (params.cb),
|
||||
res (std::make_unique<llm_graph_result>()) {
|
||||
|
@ -954,11 +954,11 @@ ggml_tensor * llm_graph_context::build_inp_cls() const {
|
|||
}
|
||||
|
||||
ggml_tensor * llm_graph_context::build_inp_s_copy() const {
|
||||
const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
|
||||
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
|
||||
|
||||
auto inp = std::make_unique<llm_graph_input_s_copy>(kv_self);
|
||||
auto inp = std::make_unique<llm_graph_input_s_copy>(kv_state);
|
||||
|
||||
const auto n_kv = kv_self->n;
|
||||
const auto n_kv = kv_state->get_n_kv();
|
||||
|
||||
auto & cur = inp->s_copy;
|
||||
|
||||
|
@ -971,11 +971,11 @@ ggml_tensor * llm_graph_context::build_inp_s_copy() const {
|
|||
}
|
||||
|
||||
ggml_tensor * llm_graph_context::build_inp_s_mask() const {
|
||||
const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
|
||||
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
|
||||
|
||||
auto inp = std::make_unique<llm_graph_input_s_mask>(kv_self);
|
||||
auto inp = std::make_unique<llm_graph_input_s_mask>(kv_state);
|
||||
|
||||
const auto n_kv = kv_self->n;
|
||||
const auto n_kv = kv_state->get_n_kv();
|
||||
|
||||
auto & cur = inp->s_mask;
|
||||
|
||||
|
@ -1025,11 +1025,11 @@ ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const {
|
|||
}
|
||||
|
||||
ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
|
||||
const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
|
||||
const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
|
||||
|
||||
auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, kv_self);
|
||||
auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, kv_state);
|
||||
|
||||
const auto n_kv = kv_self->get_n();
|
||||
const auto n_kv = kv_state->get_n_kv();
|
||||
|
||||
auto & cur = inp->pos_bucket;
|
||||
|
||||
|
@ -1231,14 +1231,14 @@ ggml_tensor * llm_graph_context::build_attn(
|
|||
}
|
||||
|
||||
llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
|
||||
const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
|
||||
const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
|
||||
|
||||
auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_self);
|
||||
auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_state);
|
||||
|
||||
{
|
||||
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
|
||||
|
||||
const auto n_kv = kv_self->get_n();
|
||||
const auto n_kv = kv_state->get_n_kv();
|
||||
|
||||
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
||||
//cb(inp->self_kq_mask, "KQ_mask", -1);
|
||||
|
@ -1268,19 +1268,19 @@ ggml_tensor * llm_graph_context::build_attn(
|
|||
ggml_build_forward_expand(gf, k_cur);
|
||||
ggml_build_forward_expand(gf, v_cur);
|
||||
|
||||
const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
|
||||
const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
|
||||
|
||||
// store to KV cache
|
||||
{
|
||||
ggml_build_forward_expand(gf, kv_self->cpy_k(ctx0, k_cur, il));
|
||||
ggml_build_forward_expand(gf, kv_self->cpy_v(ctx0, v_cur, il));
|
||||
ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
|
||||
ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
|
||||
}
|
||||
|
||||
const auto & kq_mask = inp->get_kq_mask();
|
||||
|
||||
ggml_tensor * q = q_cur;
|
||||
ggml_tensor * k = kv_self->get_k(ctx0, il);
|
||||
ggml_tensor * v = kv_self->get_v(ctx0, il);
|
||||
ggml_tensor * k = kv_state->get_k(ctx0, il);
|
||||
ggml_tensor * v = kv_state->get_v(ctx0, il);
|
||||
|
||||
ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
|
||||
cb(cur, "kqv_out", il);
|
||||
|
@ -1301,12 +1301,12 @@ ggml_tensor * llm_graph_context::build_attn(
|
|||
}
|
||||
|
||||
llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
|
||||
const llama_kv_cache_unified_iswa * kv_self = static_cast<const llama_kv_cache_unified_iswa *>(memory);
|
||||
const auto * kv_state = static_cast<const llama_kv_cache_unified_iswa_state *>(mstate);
|
||||
|
||||
auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, kv_self);
|
||||
auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, kv_state);
|
||||
|
||||
{
|
||||
const auto n_kv = kv_self->get_kv_base()->get_n();
|
||||
const auto n_kv = kv_state->get_base()->get_n_kv();
|
||||
|
||||
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
||||
//cb(inp->self_kq_mask, "KQ_mask", -1);
|
||||
|
@ -1318,7 +1318,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
|
|||
{
|
||||
GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
|
||||
|
||||
const auto n_kv = kv_self->get_kv_swa()->get_n();
|
||||
const auto n_kv = kv_state->get_swa()->get_n_kv();
|
||||
|
||||
inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
||||
//cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
|
||||
|
@ -1348,23 +1348,23 @@ ggml_tensor * llm_graph_context::build_attn(
|
|||
ggml_build_forward_expand(gf, k_cur);
|
||||
ggml_build_forward_expand(gf, v_cur);
|
||||
|
||||
const auto * kv_state_iswa = static_cast<const llama_kv_cache_unified_iswa_state *>(mstate);
|
||||
|
||||
const bool is_swa = hparams.is_swa(il);
|
||||
|
||||
const llama_kv_cache_unified_iswa * kv_self = static_cast<const llama_kv_cache_unified_iswa *>(memory);
|
||||
|
||||
const auto * kv = is_swa ? kv_self->get_kv_swa() : kv_self->get_kv_base();
|
||||
const auto * kv_state = is_swa ? kv_state_iswa->get_swa() : kv_state_iswa->get_base();
|
||||
|
||||
// store to KV cache
|
||||
{
|
||||
ggml_build_forward_expand(gf, kv->cpy_k(ctx0, k_cur, il));
|
||||
ggml_build_forward_expand(gf, kv->cpy_v(ctx0, v_cur, il));
|
||||
ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
|
||||
ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
|
||||
}
|
||||
|
||||
const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
|
||||
|
||||
ggml_tensor * q = q_cur;
|
||||
ggml_tensor * k = kv->get_k(ctx0, il);
|
||||
ggml_tensor * v = kv->get_v(ctx0, il);
|
||||
ggml_tensor * k = kv_state->get_k(ctx0, il);
|
||||
ggml_tensor * v = kv_state->get_v(ctx0, il);
|
||||
|
||||
ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
|
||||
cb(cur, "kqv_out", il);
|
||||
|
@ -1446,12 +1446,12 @@ ggml_tensor * llm_graph_context::build_copy_mask_state(
|
|||
ggml_tensor * state_mask,
|
||||
int32_t n_state,
|
||||
int32_t n_seqs) const {
|
||||
const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
|
||||
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
|
||||
|
||||
const auto n_kv = kv_self->n;
|
||||
const auto kv_head = kv_self->head;
|
||||
const auto n_kv = kv_state->get_n_kv();
|
||||
const auto kv_head = kv_state->get_head();
|
||||
|
||||
ggml_tensor * states = ggml_reshape_2d(ctx0, s, n_state, kv_self->size);
|
||||
ggml_tensor * states = ggml_reshape_2d(ctx0, s, n_state, kv_state->get_size());
|
||||
|
||||
// copy states
|
||||
// NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
|
||||
|
@ -1478,13 +1478,13 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
|
|||
ggml_tensor * state_mask,
|
||||
const llama_ubatch & ubatch,
|
||||
int il) const {
|
||||
const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
|
||||
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
|
||||
|
||||
const auto token_shift_count = hparams.token_shift_count;
|
||||
|
||||
const int64_t n_seqs = ubatch.n_seqs;
|
||||
|
||||
ggml_tensor * token_shift_all = kv_self->k_l[il];
|
||||
ggml_tensor * token_shift_all = kv_state->get_k_l(il);
|
||||
|
||||
ggml_tensor * token_shift = build_copy_mask_state(
|
||||
gf, token_shift_all, state_copy, state_mask,
|
||||
|
@ -1499,19 +1499,19 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
|
|||
ggml_tensor * token_shift,
|
||||
const llama_ubatch & ubatch,
|
||||
int il) const {
|
||||
const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
|
||||
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
|
||||
|
||||
const auto token_shift_count = hparams.token_shift_count;
|
||||
const auto n_embd = hparams.n_embd;
|
||||
|
||||
const int64_t n_seqs = ubatch.n_seqs;
|
||||
|
||||
const auto kv_head = kv_self->head;
|
||||
const auto kv_head = kv_state->get_head();
|
||||
|
||||
return ggml_cpy(
|
||||
ctx0,
|
||||
ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0),
|
||||
ggml_view_1d(ctx0, kv_self->k_l[il], hparams.n_embd_k_s() * n_seqs, hparams.n_embd_k_s() * kv_head * ggml_element_size(kv_self->k_l[il]))
|
||||
ggml_view_1d(ctx0, kv_state->get_k_l(il), hparams.n_embd_k_s()*n_seqs, hparams.n_embd_k_s()*kv_head*ggml_element_size(kv_state->get_k_l(il)))
|
||||
);
|
||||
}
|
||||
|
||||
|
|
|
@ -17,10 +17,11 @@ struct ggml_tensor;
|
|||
struct llama_ubatch;
|
||||
struct llama_cparams;
|
||||
|
||||
class llama_memory_i;
|
||||
class llama_kv_cache_unified;
|
||||
class llama_kv_cache_unified_iswa;
|
||||
class llama_kv_cache_recurrent;
|
||||
class llama_memory_state_i;
|
||||
|
||||
class llama_kv_cache_unified_state;
|
||||
class llama_kv_cache_unified_iswa_state;
|
||||
class llama_kv_cache_recurrent_state;
|
||||
|
||||
// certain models (typically multi-modal) can produce different types of graphs
|
||||
enum llm_graph_type {
|
||||
|
@ -133,7 +134,7 @@ class llm_graph_input_pos_bucket_kv : public llm_graph_input_i {
|
|||
public:
|
||||
llm_graph_input_pos_bucket_kv(
|
||||
const llama_hparams & hparams,
|
||||
const llama_kv_cache_unified * kv_self) : hparams(hparams), kv_self(kv_self) {}
|
||||
const llama_kv_cache_unified_state * kv_state) : hparams(hparams), kv_state(kv_state) {}
|
||||
virtual ~llm_graph_input_pos_bucket_kv() = default;
|
||||
|
||||
void set_input(const llama_ubatch * ubatch) override;
|
||||
|
@ -141,7 +142,7 @@ public:
|
|||
ggml_tensor * pos_bucket = nullptr; // I32 [n_kv, n_batch]
|
||||
|
||||
const llama_hparams & hparams;
|
||||
const llama_kv_cache_unified * kv_self;
|
||||
const llama_kv_cache_unified_state * kv_state;
|
||||
};
|
||||
|
||||
class llm_graph_input_out_ids : public llm_graph_input_i {
|
||||
|
@ -188,26 +189,26 @@ public:
|
|||
|
||||
class llm_graph_input_s_copy : public llm_graph_input_i {
|
||||
public:
|
||||
llm_graph_input_s_copy(const llama_kv_cache_recurrent * kv_self) : kv_self(kv_self) {}
|
||||
llm_graph_input_s_copy(const llama_kv_cache_recurrent_state * kv_state) : kv_state(kv_state) {}
|
||||
virtual ~llm_graph_input_s_copy() = default;
|
||||
|
||||
void set_input(const llama_ubatch * ubatch) override;
|
||||
|
||||
ggml_tensor * s_copy; // I32 [kv_size]
|
||||
|
||||
const llama_kv_cache_recurrent * kv_self;
|
||||
const llama_kv_cache_recurrent_state * kv_state;
|
||||
};
|
||||
|
||||
class llm_graph_input_s_mask : public llm_graph_input_i {
|
||||
public:
|
||||
llm_graph_input_s_mask(const llama_kv_cache_recurrent * kv_self) : kv_self(kv_self) {}
|
||||
llm_graph_input_s_mask(const llama_kv_cache_recurrent_state * kv_state) : kv_state(kv_state) {}
|
||||
virtual ~llm_graph_input_s_mask() = default;
|
||||
|
||||
void set_input(const llama_ubatch * ubatch) override;
|
||||
|
||||
ggml_tensor * s_mask; // F32 [1, n_kv]
|
||||
|
||||
const llama_kv_cache_recurrent * kv_self;
|
||||
const llama_kv_cache_recurrent_state * kv_state;
|
||||
};
|
||||
|
||||
class llm_graph_input_cross_embd : public llm_graph_input_i {
|
||||
|
@ -247,10 +248,10 @@ public:
|
|||
llm_graph_input_attn_kv_unified(
|
||||
const llama_hparams & hparams,
|
||||
const llama_cparams & cparams,
|
||||
const llama_kv_cache_unified * kv_self) :
|
||||
const llama_kv_cache_unified_state * kv_state) :
|
||||
hparams(hparams),
|
||||
cparams(cparams),
|
||||
kv_self(kv_self) {
|
||||
kv_state(kv_state) {
|
||||
}
|
||||
~llm_graph_input_attn_kv_unified() = default;
|
||||
|
||||
|
@ -264,7 +265,7 @@ public:
|
|||
const llama_hparams & hparams;
|
||||
const llama_cparams & cparams;
|
||||
|
||||
const llama_kv_cache_unified * kv_self;
|
||||
const llama_kv_cache_unified_state * kv_state;
|
||||
};
|
||||
|
||||
class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
|
||||
|
@ -272,10 +273,10 @@ public:
|
|||
llm_graph_input_attn_kv_unified_iswa(
|
||||
const llama_hparams & hparams,
|
||||
const llama_cparams & cparams,
|
||||
const llama_kv_cache_unified_iswa * kv_self) :
|
||||
const llama_kv_cache_unified_iswa_state * kv_state) :
|
||||
hparams(hparams),
|
||||
cparams(cparams),
|
||||
kv_self(kv_self) {
|
||||
kv_state(kv_state) {
|
||||
}
|
||||
~llm_graph_input_attn_kv_unified_iswa() = default;
|
||||
|
||||
|
@ -292,7 +293,7 @@ public:
|
|||
const llama_hparams & hparams;
|
||||
const llama_cparams & cparams;
|
||||
|
||||
const llama_kv_cache_unified_iswa * kv_self;
|
||||
const llama_kv_cache_unified_iswa_state * kv_state;
|
||||
};
|
||||
|
||||
class llm_graph_input_attn_cross : public llm_graph_input_i {
|
||||
|
@ -383,10 +384,10 @@ struct llm_graph_params {
|
|||
ggml_backend_sched_t sched;
|
||||
ggml_backend_t backend_cpu;
|
||||
|
||||
const llama_adapter_cvec * cvec;
|
||||
const llama_adapter_loras * loras;
|
||||
const llama_memory_i * memory;
|
||||
const llama_cross * cross;
|
||||
const llama_adapter_cvec * cvec;
|
||||
const llama_adapter_loras * loras;
|
||||
const llama_memory_state_i * mstate;
|
||||
const llama_cross * cross;
|
||||
|
||||
int32_t n_outputs;
|
||||
|
||||
|
@ -435,10 +436,10 @@ struct llm_graph_context {
|
|||
|
||||
ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
|
||||
|
||||
const llama_adapter_cvec * cvec;
|
||||
const llama_adapter_loras * loras;
|
||||
const llama_memory_i * memory;
|
||||
const llama_cross * cross;
|
||||
const llama_adapter_cvec * cvec;
|
||||
const llama_adapter_loras * loras;
|
||||
const llama_memory_state_i * mstate;
|
||||
const llama_cross * cross;
|
||||
|
||||
const llm_graph_cb & cb_func;
|
||||
|
||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -2,6 +2,7 @@
|
|||
|
||||
#include "llama.h"
|
||||
#include "llama-io.h"
|
||||
#include "llama-batch.h"
|
||||
#include "llama-graph.h"
|
||||
#include "llama-memory.h"
|
||||
#include "llama-kv-cells.h"
|
||||
|
@ -14,48 +15,35 @@
|
|||
|
||||
struct llama_cparams;
|
||||
struct llama_hparams;
|
||||
struct llama_ubatch;
|
||||
struct llama_sbatch;
|
||||
struct llama_model;
|
||||
struct llama_context;
|
||||
|
||||
struct llama_kv_cache : public llama_memory_i {
|
||||
virtual ~llama_kv_cache() = default;
|
||||
|
||||
// call if batch processing fails - restores the cache state
|
||||
virtual void restore() = 0;
|
||||
// split the input batch into a set of ubatches and verify that they can fit into the cache
|
||||
// return a state object containing the ubatches and KV cache state required to process them
|
||||
// check the llama_memory_state_i::get_status() for the result
|
||||
virtual llama_memory_state_ptr init_batch(
|
||||
const llama_batch & batch,
|
||||
uint32_t n_ubatch,
|
||||
bool embd_pooled,
|
||||
bool logits_all) = 0;
|
||||
|
||||
// call after successful batch processing - clears any pending state
|
||||
virtual void commit() = 0;
|
||||
// simulate full cache, used for allocating worst-case compute buffers
|
||||
virtual llama_memory_state_ptr init_full() = 0;
|
||||
|
||||
// process any pending defrag/shift/etc. operations
|
||||
// optionally call once before processing a new batch
|
||||
// return true if any operations were performed
|
||||
virtual bool update(llama_context & lctx) = 0;
|
||||
|
||||
// schedule a defrag if the fragmentation threshold is exceeded. otherwise, do nothing
|
||||
// TODO: change to
|
||||
// llama_memory_state_ptr init_defrag(float thold) = 0;
|
||||
//
|
||||
virtual void defrag_sched(float thold) = 0;
|
||||
|
||||
// simulate full cache, used for allocating worst-case compute buffers
|
||||
// TODO: remove
|
||||
virtual void set_full() = 0;
|
||||
|
||||
//
|
||||
// batch processing
|
||||
//
|
||||
|
||||
// =============================================================================================================
|
||||
// TODO: refactor and simplify this [TAG: KV_API]
|
||||
|
||||
virtual llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) = 0;
|
||||
|
||||
// different KV caches require different batch splitting strategies
|
||||
virtual llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const = 0;
|
||||
|
||||
// find an empty slot of size "n_tokens" in the cache
|
||||
virtual bool find_slot(const llama_ubatch & batch) = 0;
|
||||
|
||||
// =============================================================================================================
|
||||
|
||||
// getters
|
||||
virtual bool get_can_shift() const = 0;
|
||||
|
||||
|
@ -69,25 +57,6 @@ struct llama_kv_cache : public llama_memory_i {
|
|||
virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) = 0;
|
||||
};
|
||||
|
||||
//
|
||||
// llama_kv_cache_guard
|
||||
//
|
||||
|
||||
struct llama_kv_cache_guard {
|
||||
llama_kv_cache_guard(llama_kv_cache * kv) : kv(kv) {}
|
||||
|
||||
~llama_kv_cache_guard() {
|
||||
kv->restore();
|
||||
}
|
||||
|
||||
void commit() {
|
||||
kv->commit();
|
||||
}
|
||||
|
||||
private:
|
||||
llama_kv_cache * kv;
|
||||
};
|
||||
|
||||
//
|
||||
// llama_kv_cache_unified
|
||||
//
|
||||
|
@ -133,23 +102,18 @@ public:
|
|||
// llama_kv_cache
|
||||
//
|
||||
|
||||
void restore() override;
|
||||
void commit() override;
|
||||
llama_memory_state_ptr init_batch(
|
||||
const llama_batch & batch,
|
||||
uint32_t n_ubatch,
|
||||
bool embd_pooled,
|
||||
bool logits_all) override;
|
||||
|
||||
bool update(llama_context & ctx) override;
|
||||
llama_memory_state_ptr init_full() override;
|
||||
|
||||
bool update(llama_context & lctx) override;
|
||||
|
||||
void defrag_sched(float thold) override;
|
||||
|
||||
void set_full() override;
|
||||
|
||||
llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
|
||||
llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
|
||||
|
||||
// updates the cache head
|
||||
// Note: On success, it's important that cache.head points
|
||||
// to the first cell of the slot.
|
||||
bool find_slot(const llama_ubatch & batch) override;
|
||||
|
||||
bool get_can_shift() const override;
|
||||
|
||||
// state write/load
|
||||
|
@ -161,18 +125,40 @@ public:
|
|||
// llama_kv_cache_unified specific API
|
||||
//
|
||||
|
||||
uint32_t get_n() const;
|
||||
uint32_t get_size() const;
|
||||
|
||||
//
|
||||
// graph_build API
|
||||
//
|
||||
|
||||
uint32_t get_n_kv() const;
|
||||
|
||||
// get views of the current state of the cache
|
||||
ggml_tensor * get_k(ggml_context * ctx, int32_t il) const;
|
||||
ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;
|
||||
ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv) const;
|
||||
ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const;
|
||||
|
||||
// store k_cur and v_cur in the cache based on the current head location
|
||||
ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const;
|
||||
ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const;
|
||||
// store k_cur and v_cur in the cache based on the provided head location
|
||||
ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il, uint32_t head_cur) const;
|
||||
ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il, uint32_t head_cur) const;
|
||||
|
||||
void prune_swa(llama_seq_id seq_id, llama_pos pmin, llama_pos pmax);
|
||||
//
|
||||
// preparation API
|
||||
//
|
||||
|
||||
// find places for the provided ubatches in the cache, returns the head locations
|
||||
// return empty vector on failure
|
||||
std::vector<uint32_t> prepare(const std::vector<llama_ubatch> & ubatches);
|
||||
|
||||
// return the cell position where we can insert the ubatch
|
||||
// return -1 on failure to find a contiguous slot of kv cells
|
||||
int32_t find_slot(const llama_ubatch & ubatch) const;
|
||||
|
||||
// emplace the ubatch context into slot: [head_cur, head_cur + ubatch.n_tokens)
|
||||
void apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch);
|
||||
|
||||
//
|
||||
// set_input API
|
||||
//
|
||||
|
||||
void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
|
||||
void set_input_k_shift (ggml_tensor * dst) const;
|
||||
|
@ -194,11 +180,9 @@ private:
|
|||
bool do_defrag = false;
|
||||
bool v_trans = true; // the value tensor is transposed
|
||||
|
||||
uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot())
|
||||
|
||||
// computed before each graph build
|
||||
// TODO: cells should start to maintain this value dynamically based on the edits
|
||||
uint32_t n = 0;
|
||||
// the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
|
||||
// note: this is not part of the KV state and it's only used to speed-up the find_slot() method
|
||||
uint32_t head = 0;
|
||||
|
||||
const uint32_t n_seq_max = 1;
|
||||
|
||||
|
@ -220,24 +204,6 @@ private:
|
|||
// model layer id -> KV cache layer id
|
||||
std::unordered_map<int32_t, int32_t> map_layer_ids;
|
||||
|
||||
// recovery information used to restore the KV cells to their original state in case of a failure
|
||||
// TODO: do not store as a state in the llama_kv_cache object, instead return upon batch preparation
|
||||
// to achieve that, first need to refactor the llama_kv_cache interface [TAG: KV_API]
|
||||
struct {
|
||||
void clear() {
|
||||
states.clear();
|
||||
}
|
||||
|
||||
struct state {
|
||||
uint32_t i;
|
||||
|
||||
llama_kv_cells_unified cells;
|
||||
};
|
||||
|
||||
// stack with the partial states before each ubatch
|
||||
std::vector<state> states;
|
||||
} recovery;
|
||||
|
||||
// defrag
|
||||
struct {
|
||||
std::vector<uint32_t> ids;
|
||||
|
@ -279,13 +245,88 @@ private:
|
|||
bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
|
||||
};
|
||||
|
||||
class llama_kv_cache_unified_state : public llama_memory_state_i {
|
||||
public:
|
||||
// used for errors
|
||||
llama_kv_cache_unified_state(llama_memory_status status);
|
||||
|
||||
// used to create a full-cache state
|
||||
llama_kv_cache_unified_state(
|
||||
llama_memory_status status,
|
||||
llama_kv_cache_unified * kv);
|
||||
|
||||
// used to create a state from a batch
|
||||
llama_kv_cache_unified_state(
|
||||
llama_memory_status status,
|
||||
llama_kv_cache_unified * kv,
|
||||
llama_sbatch sbatch,
|
||||
std::vector<uint32_t> heads,
|
||||
std::vector<llama_ubatch> ubatches);
|
||||
|
||||
virtual ~llama_kv_cache_unified_state();
|
||||
|
||||
//
|
||||
// llama_memory_state_i
|
||||
//
|
||||
|
||||
bool next() override;
|
||||
bool apply() override;
|
||||
|
||||
std::vector<int64_t> & out_ids() override;
|
||||
|
||||
llama_memory_status get_status() const override;
|
||||
const llama_ubatch & get_ubatch() const override;
|
||||
|
||||
//
|
||||
// llama_kv_cache_unified_state specific API
|
||||
//
|
||||
|
||||
uint32_t get_n_kv() const;
|
||||
|
||||
// get views of the current state of the cache
|
||||
ggml_tensor * get_k(ggml_context * ctx, int32_t il) const;
|
||||
ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;
|
||||
|
||||
// store k_cur and v_cur in the cache based on the provided head location
|
||||
ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const;
|
||||
ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const;
|
||||
|
||||
void set_input_k_shift(ggml_tensor * dst) const;
|
||||
|
||||
void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
|
||||
void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
|
||||
|
||||
private:
|
||||
const llama_memory_status status;
|
||||
|
||||
llama_kv_cache_unified * kv;
|
||||
|
||||
llama_sbatch sbatch;
|
||||
|
||||
// the index of the next ubatch to process
|
||||
size_t i_next = 0;
|
||||
|
||||
std::vector<uint32_t> heads;
|
||||
std::vector<llama_ubatch> ubatches;
|
||||
|
||||
//
|
||||
// data needed for building the compute graph for the current ubatch:
|
||||
//
|
||||
|
||||
// a heuristic, to avoid attending the full cache if it is not yet utilized
|
||||
// as the cache gets filled, the benefit from this heuristic disappears
|
||||
int32_t n_kv;
|
||||
|
||||
// the beginning of the current slot in which the ubatch will be inserted
|
||||
int32_t head;
|
||||
};
|
||||
|
||||
//
|
||||
// llama_kv_cache_unified_iswa
|
||||
//
|
||||
|
||||
// utilizes two instances of llama_kv_cache_unified
|
||||
// the first instance is for the non-SWA layers of the model and the second instance is for the SWA layers
|
||||
// upon successful commit, the SWA cache removes old tokens outside the n_swa window
|
||||
|
||||
class llama_kv_cache_unified_iswa : public llama_kv_cache {
|
||||
public:
|
||||
|
@ -298,7 +339,7 @@ public:
|
|||
bool swa_full,
|
||||
uint32_t kv_size,
|
||||
uint32_t n_seq_max,
|
||||
uint32_t n_batch,
|
||||
uint32_t n_ubatch,
|
||||
uint32_t n_pad);
|
||||
|
||||
~llama_kv_cache_unified_iswa() = default;
|
||||
|
@ -322,20 +363,18 @@ public:
|
|||
// llama_kv_cache
|
||||
//
|
||||
|
||||
void restore() override;
|
||||
void commit() override;
|
||||
llama_memory_state_ptr init_batch(
|
||||
const llama_batch & batch,
|
||||
uint32_t n_ubatch,
|
||||
bool embd_pooled,
|
||||
bool logits_all) override;
|
||||
|
||||
bool update(llama_context & ctx) override;
|
||||
llama_memory_state_ptr init_full() override;
|
||||
|
||||
bool update(llama_context & lctx) override;
|
||||
|
||||
void defrag_sched(float thold) override;
|
||||
|
||||
void set_full() override;
|
||||
|
||||
llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
|
||||
llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
|
||||
|
||||
bool find_slot(const llama_ubatch & batch) override;
|
||||
|
||||
bool get_can_shift() const override;
|
||||
|
||||
// state write/load
|
||||
|
@ -347,58 +386,80 @@ public:
|
|||
// llama_kv_cache_unified_iswa specific API
|
||||
//
|
||||
|
||||
llama_kv_cache_unified * get_kv_base() const;
|
||||
llama_kv_cache_unified * get_kv_swa () const;
|
||||
llama_kv_cache_unified * get_base() const;
|
||||
llama_kv_cache_unified * get_swa () const;
|
||||
|
||||
private:
|
||||
const llama_hparams & hparams;
|
||||
|
||||
bool do_prune = true;
|
||||
|
||||
struct {
|
||||
struct entry {
|
||||
llama_pos pmin;
|
||||
llama_pos pmax;
|
||||
};
|
||||
|
||||
void clear() {
|
||||
pos.clear();
|
||||
}
|
||||
|
||||
// used to perform SWA pruning of old tokens
|
||||
std::unordered_map<llama_seq_id, entry> pos;
|
||||
} pending;
|
||||
|
||||
std::unique_ptr<llama_kv_cache_unified> kv_base;
|
||||
std::unique_ptr<llama_kv_cache_unified> kv_swa;
|
||||
};
|
||||
|
||||
class llama_kv_cache_unified_iswa_state : public llama_memory_state_i {
|
||||
public:
|
||||
// used for errors
|
||||
llama_kv_cache_unified_iswa_state(llama_memory_status status);
|
||||
|
||||
// used to create a full-cache state
|
||||
llama_kv_cache_unified_iswa_state(
|
||||
llama_memory_status status,
|
||||
llama_kv_cache_unified_iswa * kv);
|
||||
|
||||
// used to create a state from a batch
|
||||
llama_kv_cache_unified_iswa_state(
|
||||
llama_memory_status status,
|
||||
llama_kv_cache_unified_iswa * kv,
|
||||
llama_sbatch sbatch,
|
||||
std::vector<uint32_t> heads_base,
|
||||
std::vector<uint32_t> heads_swa,
|
||||
std::vector<llama_ubatch> ubatches);
|
||||
|
||||
virtual ~llama_kv_cache_unified_iswa_state();
|
||||
|
||||
//
|
||||
// llama_memory_state_i
|
||||
//
|
||||
|
||||
bool next() override;
|
||||
bool apply() override;
|
||||
|
||||
std::vector<int64_t> & out_ids() override;
|
||||
|
||||
llama_memory_status get_status() const override;
|
||||
const llama_ubatch & get_ubatch() const override;
|
||||
|
||||
//
|
||||
// llama_kv_cache_unified_iswa_state specific API
|
||||
//
|
||||
|
||||
const llama_kv_cache_unified_state * get_base() const;
|
||||
const llama_kv_cache_unified_state * get_swa() const;
|
||||
|
||||
private:
|
||||
const llama_memory_status status;
|
||||
|
||||
//llama_kv_cache_unified_iswa * kv;
|
||||
|
||||
llama_sbatch sbatch;
|
||||
|
||||
// the index of the next ubatch to process
|
||||
size_t i_next = 0;
|
||||
|
||||
std::vector<llama_ubatch> ubatches;
|
||||
|
||||
std::unique_ptr<llama_kv_cache_unified_state> state_base;
|
||||
std::unique_ptr<llama_kv_cache_unified_state> state_swa;
|
||||
};
|
||||
|
||||
//
|
||||
// llama_kv_cache_recurrent
|
||||
//
|
||||
|
||||
// TODO: extract the KV cache state used for graph computation into llama_kv_cache_recurrent_state_i
|
||||
// see the implementation of llama_kv_cache_unified_state_i for an example how to do it
|
||||
class llama_kv_cache_recurrent : public llama_kv_cache {
|
||||
public:
|
||||
struct kv_cell {
|
||||
llama_pos pos = -1;
|
||||
int32_t src = -1; // used to copy states
|
||||
int32_t tail = -1;
|
||||
|
||||
std::set<llama_seq_id> seq_id;
|
||||
|
||||
bool has_seq_id(const llama_seq_id & id) const {
|
||||
return seq_id.find(id) != seq_id.end();
|
||||
}
|
||||
|
||||
bool is_empty() const {
|
||||
return seq_id.empty();
|
||||
}
|
||||
|
||||
bool is_same_seq(const kv_cell & other) const {
|
||||
return seq_id == other.seq_id;
|
||||
}
|
||||
};
|
||||
|
||||
llama_kv_cache_recurrent(
|
||||
const llama_model & model,
|
||||
ggml_type type_k,
|
||||
|
@ -428,19 +489,22 @@ public:
|
|||
// llama_kv_cache
|
||||
//
|
||||
|
||||
void restore() override;
|
||||
void commit() override;
|
||||
llama_memory_state_ptr init_batch(
|
||||
const llama_batch & batch,
|
||||
uint32_t n_ubatch,
|
||||
bool embd_pooled,
|
||||
bool logits_all) override;
|
||||
|
||||
bool update(llama_context & ctx) override;
|
||||
llama_memory_state_ptr init_full() override;
|
||||
|
||||
bool update(llama_context & lctx) override;
|
||||
|
||||
void defrag_sched(float thold) override;
|
||||
|
||||
void set_full() override;
|
||||
bool prepare(const std::vector<llama_ubatch> & ubatches);
|
||||
|
||||
llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
|
||||
llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
|
||||
|
||||
bool find_slot(const llama_ubatch & batch) override;
|
||||
// find a contiguous slot of kv cells and emplace the ubatch there
|
||||
bool find_slot(const llama_ubatch & ubatch);
|
||||
|
||||
bool get_can_shift() const override;
|
||||
|
||||
|
@ -460,6 +524,27 @@ public:
|
|||
// computed before each graph build
|
||||
uint32_t n = 0;
|
||||
|
||||
// TODO: optimize for recurrent state needs
|
||||
struct kv_cell {
|
||||
llama_pos pos = -1;
|
||||
int32_t src = -1; // used to copy states
|
||||
int32_t tail = -1;
|
||||
|
||||
std::set<llama_seq_id> seq_id;
|
||||
|
||||
bool has_seq_id(const llama_seq_id & id) const {
|
||||
return seq_id.find(id) != seq_id.end();
|
||||
}
|
||||
|
||||
bool is_empty() const {
|
||||
return seq_id.empty();
|
||||
}
|
||||
|
||||
bool is_same_seq(const kv_cell & other) const {
|
||||
return seq_id == other.seq_id;
|
||||
}
|
||||
};
|
||||
|
||||
std::vector<kv_cell> cells;
|
||||
|
||||
std::vector<ggml_tensor *> k_l; // per layer
|
||||
|
@ -469,26 +554,11 @@ private:
|
|||
//const llama_model & model;
|
||||
const llama_hparams & hparams;
|
||||
|
||||
// commit/restore cache
|
||||
// TODO: rework for recurrent cache
|
||||
struct slot_range {
|
||||
uint32_t c0 = 0; // note: these are cell indices, not sequence positions
|
||||
uint32_t c1 = 0;
|
||||
};
|
||||
|
||||
// pending cell updates that are not yet committed
|
||||
struct {
|
||||
std::vector<slot_range> ranges;
|
||||
} pending;
|
||||
|
||||
const uint32_t n_seq_max = 1;
|
||||
|
||||
std::vector<ggml_context_ptr> ctxs;
|
||||
std::vector<ggml_backend_buffer_ptr> bufs;
|
||||
|
||||
// find how many cells are currently in use
|
||||
uint32_t cell_max() const;
|
||||
|
||||
size_t total_size() const;
|
||||
|
||||
size_t size_k_bytes() const;
|
||||
|
@ -500,3 +570,67 @@ private:
|
|||
bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
|
||||
bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
|
||||
};
|
||||
|
||||
class llama_kv_cache_recurrent_state : public llama_memory_state_i {
|
||||
public:
|
||||
// used for errors
|
||||
llama_kv_cache_recurrent_state(llama_memory_status status);
|
||||
|
||||
// used to create a full-cache state
|
||||
llama_kv_cache_recurrent_state(
|
||||
llama_memory_status status,
|
||||
llama_kv_cache_recurrent * kv);
|
||||
|
||||
// used to create a state from a batch
|
||||
llama_kv_cache_recurrent_state(
|
||||
llama_memory_status status,
|
||||
llama_kv_cache_recurrent * kv,
|
||||
llama_sbatch sbatch,
|
||||
std::vector<llama_ubatch> ubatches);
|
||||
|
||||
virtual ~llama_kv_cache_recurrent_state();
|
||||
|
||||
//
|
||||
// llama_memory_state_i
|
||||
//
|
||||
|
||||
bool next() override;
|
||||
bool apply() override;
|
||||
|
||||
std::vector<int64_t> & out_ids() override;
|
||||
|
||||
llama_memory_status get_status() const override;
|
||||
const llama_ubatch & get_ubatch() const override;
|
||||
|
||||
//
|
||||
// llama_kv_cache_recurrent_state specific API
|
||||
//
|
||||
|
||||
uint32_t get_n_kv() const;
|
||||
uint32_t get_head() const;
|
||||
uint32_t get_size() const;
|
||||
|
||||
ggml_tensor * get_k_l(int32_t il) const;
|
||||
ggml_tensor * get_v_l(int32_t il) const;
|
||||
|
||||
int32_t s_copy(int i) const;
|
||||
float s_mask(int i) const;
|
||||
|
||||
private:
|
||||
const llama_memory_status status;
|
||||
|
||||
llama_kv_cache_recurrent * kv;
|
||||
|
||||
llama_sbatch sbatch;
|
||||
|
||||
size_t i_next = 0;
|
||||
|
||||
std::vector<llama_ubatch> ubatches;
|
||||
|
||||
//
|
||||
// data needed for building the compute graph for the current ubatch:
|
||||
// TODO: extract all the state like `head` and `n` here
|
||||
//
|
||||
|
||||
const bool is_full = false;
|
||||
};
|
||||
|
|
|
@ -68,12 +68,6 @@ public:
|
|||
// the index of the last cell that is used + 1
|
||||
// return 0 if no cells are used
|
||||
uint32_t used_max_p1() const {
|
||||
#if 0
|
||||
if (!seq_pos[0].empty()) printf("kv_cells: min[0] = %5d, max[0] = %5d\n", *seq_pos[0].begin(), *seq_pos[0].rbegin());
|
||||
if (!seq_pos[1].empty()) printf("kv_cells: min[1] = %5d, max[1] = %5d\n", *seq_pos[1].begin(), *seq_pos[1].rbegin());
|
||||
if (!seq_pos[2].empty()) printf("kv_cells: min[2] = %5d, max[2] = %5d\n", *seq_pos[2].begin(), *seq_pos[2].rbegin());
|
||||
#endif
|
||||
|
||||
return used.empty() ? 0 : *used.rbegin() + 1;
|
||||
}
|
||||
|
||||
|
@ -144,6 +138,19 @@ public:
|
|||
}
|
||||
}
|
||||
|
||||
// clear a non-empty cell
|
||||
void rm(uint32_t i) {
|
||||
assert(i < pos.size());
|
||||
assert(pos[i] != -1);
|
||||
|
||||
seq_pos_rm(i);
|
||||
|
||||
pos[i] = -1;
|
||||
seq[i].reset();
|
||||
|
||||
used.erase(i);
|
||||
}
|
||||
|
||||
// note: call only if the cell has seq_id
|
||||
// return true if the cell becomes empty
|
||||
bool seq_rm(uint32_t i, llama_seq_id seq_id) {
|
||||
|
@ -196,6 +203,15 @@ public:
|
|||
return false;
|
||||
}
|
||||
|
||||
// number of different sequences in the cell
|
||||
int seq_count(uint32_t i) const {
|
||||
assert(i < pos.size());
|
||||
assert(pos[i] != -1);
|
||||
|
||||
return seq[i].count();
|
||||
}
|
||||
|
||||
// check if the cell contains seq_id
|
||||
bool seq_has(uint32_t i, llama_seq_id seq_id) const {
|
||||
assert(i < pos.size());
|
||||
assert(seq_id >= 0);
|
||||
|
@ -213,6 +229,20 @@ public:
|
|||
seq_pos[seq_id].insert(pos[i]);
|
||||
}
|
||||
|
||||
// return the sequence id of this cell
|
||||
// note: call only for cells with exactly one sequence
|
||||
llama_seq_id seq_get(uint32_t i) const {
|
||||
assert(seq[i].count() == 1);
|
||||
|
||||
for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
|
||||
if (seq[i].test(s)) {
|
||||
return s;
|
||||
}
|
||||
}
|
||||
|
||||
return -1;
|
||||
}
|
||||
|
||||
// the minimum position of sequence seq_id currently present in any of the cells
|
||||
// return -1 if the sequence is not present
|
||||
llama_pos seq_pos_min(llama_seq_id seq_id) const {
|
||||
|
@ -268,6 +298,7 @@ public:
|
|||
void pos_set(uint32_t i, llama_pos p) {
|
||||
assert(i < pos.size());
|
||||
assert(pos[i] == -1);
|
||||
assert(seq[i].none());
|
||||
|
||||
pos[i] = p;
|
||||
|
||||
|
|
|
@ -2,6 +2,11 @@
|
|||
|
||||
#include "llama.h"
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
struct llama_ubatch;
|
||||
|
||||
struct llama_memory_params {
|
||||
// kv cache
|
||||
ggml_type type_k;
|
||||
|
@ -30,3 +35,42 @@ public:
|
|||
|
||||
virtual bool get_can_edit() const = 0;
|
||||
};
|
||||
|
||||
enum llama_memory_status {
|
||||
LLAMA_MEMORY_STATUS_SUCCESS = 0,
|
||||
LLAMA_MEMORY_STATUS_FAILED_PREPARE,
|
||||
LLAMA_MEMORY_STATUS_FAILED_COMPUTE,
|
||||
};
|
||||
|
||||
// the interface for managing the memory state during batch processing
|
||||
// this interface is implemented per memory type. see:
|
||||
// - llama_kv_cache_unified_state
|
||||
// - llama_kv_cache_unified_iswa_state
|
||||
// ...
|
||||
//
|
||||
// the only method that can mutate the memory and the memory state is llama_memory_i::apply()
|
||||
//
|
||||
// TODO: rename to llama_memory_context_i ?
|
||||
class llama_memory_state_i {
|
||||
public:
|
||||
virtual ~llama_memory_state_i() = default;
|
||||
|
||||
// consume the current ubatch from the state and proceed to the next one
|
||||
// return false if we are done
|
||||
virtual bool next() = 0;
|
||||
|
||||
// apply the memory state for the current ubatch to the memory object
|
||||
// return false on failure
|
||||
virtual bool apply() = 0;
|
||||
|
||||
// TODO: this might get reworked in the future when refactoring llama_batch
|
||||
virtual std::vector<int64_t> & out_ids() = 0;
|
||||
|
||||
// get the current ubatch
|
||||
virtual const llama_ubatch & get_ubatch() const = 0;
|
||||
|
||||
// get the status of the memory state
|
||||
virtual llama_memory_status get_status() const = 0;
|
||||
};
|
||||
|
||||
using llama_memory_state_ptr = std::unique_ptr<llama_memory_state_i>;
|
||||
|
|
|
@ -8992,9 +8992,9 @@ struct llm_build_mamba : public llm_graph_context {
|
|||
ggml_tensor * state_mask,
|
||||
const llama_ubatch & ubatch,
|
||||
int il) const {
|
||||
const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
|
||||
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
|
||||
|
||||
const auto kv_head = kv_self->head;
|
||||
const auto kv_head = kv_state->get_head();
|
||||
|
||||
const int64_t d_conv = hparams.ssm_d_conv;
|
||||
const int64_t d_inner = hparams.ssm_d_inner;
|
||||
|
@ -9012,8 +9012,8 @@ struct llm_build_mamba : public llm_graph_context {
|
|||
GGML_ASSERT(ubatch.equal_seqs);
|
||||
GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
|
||||
|
||||
ggml_tensor * conv_states_all = kv_self->k_l[il];
|
||||
ggml_tensor * ssm_states_all = kv_self->v_l[il];
|
||||
ggml_tensor * conv_states_all = kv_state->get_k_l(il);
|
||||
ggml_tensor * ssm_states_all = kv_state->get_v_l(il);
|
||||
|
||||
// (ab)using the KV cache to store the states
|
||||
ggml_tensor * conv = build_copy_mask_state(
|
||||
|
@ -11740,7 +11740,7 @@ struct llm_build_rwkv6_base : public llm_graph_context {
|
|||
ggml_tensor * state_mask,
|
||||
const llama_ubatch & ubatch,
|
||||
int il) const {
|
||||
const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
|
||||
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
|
||||
|
||||
const auto n_tokens = ubatch.n_tokens;
|
||||
const auto n_seqs = ubatch.n_seqs;
|
||||
|
@ -11750,7 +11750,7 @@ struct llm_build_rwkv6_base : public llm_graph_context {
|
|||
const auto n_head = n_embd / head_size;
|
||||
const auto n_head_kv = hparams.n_head_kv(il);
|
||||
|
||||
const auto kv_head = kv_self->head;
|
||||
const auto kv_head = kv_state->get_head();
|
||||
|
||||
const auto & layer = model.layers[il];
|
||||
|
||||
|
@ -11862,7 +11862,7 @@ struct llm_build_rwkv6_base : public llm_graph_context {
|
|||
}
|
||||
|
||||
ggml_tensor * wkv_state = build_copy_mask_state(
|
||||
gf, kv_self->v_l[il], state_copy, state_mask,
|
||||
gf, kv_state->get_v_l(il), state_copy, state_mask,
|
||||
hparams.n_embd_v_s(), n_seqs);
|
||||
|
||||
ggml_tensor * wkv_output;
|
||||
|
@ -11881,9 +11881,9 @@ struct llm_build_rwkv6_base : public llm_graph_context {
|
|||
wkv_state,
|
||||
ggml_view_1d(
|
||||
ctx0,
|
||||
kv_self->v_l[il],
|
||||
kv_state->get_v_l(il),
|
||||
hparams.n_embd_v_s() * n_seqs,
|
||||
hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_self->v_l[il])
|
||||
hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_state->get_v_l(il))
|
||||
)
|
||||
)
|
||||
);
|
||||
|
@ -12136,7 +12136,7 @@ struct llm_build_rwkv7_base : public llm_graph_context {
|
|||
ggml_tensor *& first_layer_value,
|
||||
const llama_ubatch & ubatch,
|
||||
int il) const {
|
||||
const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
|
||||
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
|
||||
|
||||
const auto n_tokens = ubatch.n_tokens;
|
||||
const auto n_seqs = ubatch.n_seqs;
|
||||
|
@ -12145,7 +12145,7 @@ struct llm_build_rwkv7_base : public llm_graph_context {
|
|||
const auto head_count = n_embd / head_size;
|
||||
const auto n_seq_tokens = ubatch.n_seq_tokens;
|
||||
|
||||
const auto kv_head = kv_self->head;
|
||||
const auto kv_head = kv_state->get_head();
|
||||
|
||||
const auto & layer = model.layers[il];
|
||||
|
||||
|
@ -12216,7 +12216,7 @@ struct llm_build_rwkv7_base : public llm_graph_context {
|
|||
a = ggml_reshape_3d(ctx0, a, head_size, head_count, n_tokens);
|
||||
|
||||
ggml_tensor * wkv_state = build_copy_mask_state(
|
||||
gf, kv_self->v_l[il], state_copy, state_mask,
|
||||
gf, kv_state->get_v_l(il), state_copy, state_mask,
|
||||
hparams.n_embd_v_s(), n_seqs);
|
||||
|
||||
ggml_tensor * wkv_output = ggml_rwkv_wkv7(ctx0, r, w, k, v, ggml_neg(ctx0, kk), ggml_mul(ctx0, kk, a), wkv_state);
|
||||
|
@ -12230,9 +12230,9 @@ struct llm_build_rwkv7_base : public llm_graph_context {
|
|||
wkv_state,
|
||||
ggml_view_1d(
|
||||
ctx0,
|
||||
kv_self->v_l[il],
|
||||
kv_state->get_v_l(il),
|
||||
hparams.n_embd_v_s() * n_seqs,
|
||||
hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_self->v_l[il])
|
||||
hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_state->get_v_l(il))
|
||||
)
|
||||
)
|
||||
);
|
||||
|
@ -13330,7 +13330,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
|||
params.swa_full,
|
||||
cparams.n_ctx,
|
||||
cparams.n_seq_max,
|
||||
cparams.n_batch,
|
||||
cparams.n_ubatch,
|
||||
padding);
|
||||
} else {
|
||||
GGML_ASSERT(!hparams.is_swa_any());
|
||||
|
@ -13693,6 +13693,10 @@ int32_t llama_model_n_head_kv(const llama_model * model) {
|
|||
return model->hparams.n_head_kv();
|
||||
}
|
||||
|
||||
int32_t llama_model_n_swa(const llama_model * model) {
|
||||
return model->hparams.n_swa;
|
||||
}
|
||||
|
||||
// deprecated
|
||||
int32_t llama_n_ctx_train(const llama_model * model) {
|
||||
return llama_model_n_ctx_train(model);
|
||||
|
|
|
@ -3,7 +3,6 @@
|
|||
|
||||
#include "ggml.h"
|
||||
#include "llama.h"
|
||||
#include "clip.h"
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
|
Binary file not shown.
|
@ -2016,11 +2016,6 @@ struct server_context {
|
|||
params_base.n_cache_reuse = 0;
|
||||
SRV_WRN("%s\n", "cache_reuse is not supported by this context, it will be disabled");
|
||||
}
|
||||
|
||||
if (!params_base.speculative.model.path.empty()) {
|
||||
SRV_ERR("%s\n", "err: speculative decode is not supported by this context");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
|
@ -3214,9 +3209,18 @@ struct server_context {
|
|||
}
|
||||
|
||||
if (slot.n_past > 0 && slot.n_past < (int) slot.cache_tokens.size()) {
|
||||
if (llama_kv_self_seq_pos_min(ctx, slot.id) > 0) {
|
||||
const auto pos_min = llama_kv_self_seq_pos_min(ctx, slot.id);
|
||||
if (pos_min == -1) {
|
||||
SLT_ERR(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min);
|
||||
GGML_ABORT("pos_min == -1, but n_past > 0 - should not happen: https://github.com/ggml-org/llama.cpp/pull/13833#discussion_r2116181237");
|
||||
}
|
||||
|
||||
const auto n_swa = llama_model_n_swa(model);
|
||||
if (pos_min > slot.n_past - n_swa) {
|
||||
SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min, n_swa);
|
||||
SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA, see %s)\n",
|
||||
"https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
|
||||
llama_kv_self_seq_rm(ctx, slot.id, 0, -1);
|
||||
slot.n_past = 0;
|
||||
}
|
||||
}
|
||||
|
@ -3379,8 +3383,10 @@ struct server_context {
|
|||
}
|
||||
}
|
||||
|
||||
int32_t i_next = 0;
|
||||
|
||||
// process the created batch of tokens
|
||||
for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
|
||||
for (int32_t i = 0; i < batch.n_tokens; i = i_next) {
|
||||
const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
|
||||
|
||||
llama_batch batch_view = {
|
||||
|
@ -3425,13 +3431,18 @@ struct server_context {
|
|||
|
||||
// retry with half the batch size to try to find a free slot in the KV cache
|
||||
n_batch /= 2;
|
||||
i -= n_batch;
|
||||
|
||||
SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size - try increasing it via the context size or enable defragmentation, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret);
|
||||
SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret);
|
||||
|
||||
continue; // continue loop of n_batch
|
||||
}
|
||||
|
||||
// move the head of the batch forward with the number of tokens we just processed
|
||||
i_next = i + n_tokens;
|
||||
|
||||
// on successful decode, restore the original batch size
|
||||
n_batch = llama_n_batch(ctx);
|
||||
|
||||
for (auto & slot : slots) {
|
||||
if (slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens)) {
|
||||
continue; // continue loop of slots
|
||||
|
|
|
@ -5,21 +5,24 @@ import { AppContextProvider, useAppContext } from './utils/app.context';
|
|||
import ChatScreen from './components/ChatScreen';
|
||||
import SettingDialog from './components/SettingDialog';
|
||||
import { Toaster } from 'react-hot-toast';
|
||||
import { ModalProvider } from './components/ModalProvider';
|
||||
|
||||
function App() {
|
||||
return (
|
||||
<HashRouter>
|
||||
<div className="flex flex-row drawer lg:drawer-open">
|
||||
<AppContextProvider>
|
||||
<Routes>
|
||||
<Route element={<AppLayout />}>
|
||||
<Route path="/chat/:convId" element={<ChatScreen />} />
|
||||
<Route path="*" element={<ChatScreen />} />
|
||||
</Route>
|
||||
</Routes>
|
||||
</AppContextProvider>
|
||||
</div>
|
||||
</HashRouter>
|
||||
<ModalProvider>
|
||||
<HashRouter>
|
||||
<div className="flex flex-row drawer lg:drawer-open">
|
||||
<AppContextProvider>
|
||||
<Routes>
|
||||
<Route element={<AppLayout />}>
|
||||
<Route path="/chat/:convId" element={<ChatScreen />} />
|
||||
<Route path="*" element={<ChatScreen />} />
|
||||
</Route>
|
||||
</Routes>
|
||||
</AppContextProvider>
|
||||
</div>
|
||||
</HashRouter>
|
||||
</ModalProvider>
|
||||
);
|
||||
}
|
||||
|
||||
|
|
151
tools/server/webui/src/components/ModalProvider.tsx
Normal file
151
tools/server/webui/src/components/ModalProvider.tsx
Normal file
|
@ -0,0 +1,151 @@
|
|||
import React, { createContext, useState, useContext } from 'react';
|
||||
|
||||
type ModalContextType = {
|
||||
showConfirm: (message: string) => Promise<boolean>;
|
||||
showPrompt: (
|
||||
message: string,
|
||||
defaultValue?: string
|
||||
) => Promise<string | undefined>;
|
||||
showAlert: (message: string) => Promise<void>;
|
||||
};
|
||||
const ModalContext = createContext<ModalContextType>(null!);
|
||||
|
||||
interface ModalState<T> {
|
||||
isOpen: boolean;
|
||||
message: string;
|
||||
defaultValue?: string;
|
||||
resolve: ((value: T) => void) | null;
|
||||
}
|
||||
|
||||
export function ModalProvider({ children }: { children: React.ReactNode }) {
|
||||
const [confirmState, setConfirmState] = useState<ModalState<boolean>>({
|
||||
isOpen: false,
|
||||
message: '',
|
||||
resolve: null,
|
||||
});
|
||||
const [promptState, setPromptState] = useState<
|
||||
ModalState<string | undefined>
|
||||
>({ isOpen: false, message: '', resolve: null });
|
||||
const [alertState, setAlertState] = useState<ModalState<void>>({
|
||||
isOpen: false,
|
||||
message: '',
|
||||
resolve: null,
|
||||
});
|
||||
const inputRef = React.useRef<HTMLInputElement>(null);
|
||||
|
||||
const showConfirm = (message: string): Promise<boolean> => {
|
||||
return new Promise((resolve) => {
|
||||
setConfirmState({ isOpen: true, message, resolve });
|
||||
});
|
||||
};
|
||||
|
||||
const showPrompt = (
|
||||
message: string,
|
||||
defaultValue?: string
|
||||
): Promise<string | undefined> => {
|
||||
return new Promise((resolve) => {
|
||||
setPromptState({ isOpen: true, message, defaultValue, resolve });
|
||||
});
|
||||
};
|
||||
|
||||
const showAlert = (message: string): Promise<void> => {
|
||||
return new Promise((resolve) => {
|
||||
setAlertState({ isOpen: true, message, resolve });
|
||||
});
|
||||
};
|
||||
|
||||
const handleConfirm = (result: boolean) => {
|
||||
confirmState.resolve?.(result);
|
||||
setConfirmState({ isOpen: false, message: '', resolve: null });
|
||||
};
|
||||
|
||||
const handlePrompt = (result?: string) => {
|
||||
promptState.resolve?.(result);
|
||||
setPromptState({ isOpen: false, message: '', resolve: null });
|
||||
};
|
||||
|
||||
const handleAlertClose = () => {
|
||||
alertState.resolve?.();
|
||||
setAlertState({ isOpen: false, message: '', resolve: null });
|
||||
};
|
||||
|
||||
return (
|
||||
<ModalContext.Provider value={{ showConfirm, showPrompt, showAlert }}>
|
||||
{children}
|
||||
|
||||
{/* Confirm Modal */}
|
||||
{confirmState.isOpen && (
|
||||
<dialog className="modal modal-open z-[1100]">
|
||||
<div className="modal-box">
|
||||
<h3 className="font-bold text-lg">{confirmState.message}</h3>
|
||||
<div className="modal-action">
|
||||
<button
|
||||
className="btn btn-ghost"
|
||||
onClick={() => handleConfirm(false)}
|
||||
>
|
||||
Cancel
|
||||
</button>
|
||||
<button
|
||||
className="btn btn-error"
|
||||
onClick={() => handleConfirm(true)}
|
||||
>
|
||||
Confirm
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</dialog>
|
||||
)}
|
||||
|
||||
{/* Prompt Modal */}
|
||||
{promptState.isOpen && (
|
||||
<dialog className="modal modal-open z-[1100]">
|
||||
<div className="modal-box">
|
||||
<h3 className="font-bold text-lg">{promptState.message}</h3>
|
||||
<input
|
||||
type="text"
|
||||
className="input input-bordered w-full mt-2"
|
||||
defaultValue={promptState.defaultValue}
|
||||
ref={inputRef}
|
||||
onKeyDown={(e) => {
|
||||
if (e.key === 'Enter') {
|
||||
handlePrompt((e.target as HTMLInputElement).value);
|
||||
}
|
||||
}}
|
||||
/>
|
||||
<div className="modal-action">
|
||||
<button className="btn btn-ghost" onClick={() => handlePrompt()}>
|
||||
Cancel
|
||||
</button>
|
||||
<button
|
||||
className="btn btn-primary"
|
||||
onClick={() => handlePrompt(inputRef.current?.value)}
|
||||
>
|
||||
Submit
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</dialog>
|
||||
)}
|
||||
|
||||
{/* Alert Modal */}
|
||||
{alertState.isOpen && (
|
||||
<dialog className="modal modal-open z-[1100]">
|
||||
<div className="modal-box">
|
||||
<h3 className="font-bold text-lg">{alertState.message}</h3>
|
||||
<div className="modal-action">
|
||||
<button className="btn" onClick={handleAlertClose}>
|
||||
OK
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</dialog>
|
||||
)}
|
||||
</ModalContext.Provider>
|
||||
);
|
||||
}
|
||||
|
||||
export function useModals() {
|
||||
const context = useContext(ModalContext);
|
||||
if (!context) throw new Error('useModals must be used within ModalProvider');
|
||||
return context;
|
||||
}
|
|
@ -13,6 +13,7 @@ import {
|
|||
SquaresPlusIcon,
|
||||
} from '@heroicons/react/24/outline';
|
||||
import { OpenInNewTab } from '../utils/common';
|
||||
import { useModals } from './ModalProvider';
|
||||
|
||||
type SettKey = keyof typeof CONFIG_DEFAULT;
|
||||
|
||||
|
@ -282,14 +283,15 @@ export default function SettingDialog({
|
|||
const [localConfig, setLocalConfig] = useState<typeof CONFIG_DEFAULT>(
|
||||
JSON.parse(JSON.stringify(config))
|
||||
);
|
||||
const { showConfirm, showAlert } = useModals();
|
||||
|
||||
const resetConfig = () => {
|
||||
if (window.confirm('Are you sure you want to reset all settings?')) {
|
||||
const resetConfig = async () => {
|
||||
if (await showConfirm('Are you sure you want to reset all settings?')) {
|
||||
setLocalConfig(CONFIG_DEFAULT);
|
||||
}
|
||||
};
|
||||
|
||||
const handleSave = () => {
|
||||
const handleSave = async () => {
|
||||
// copy the local config to prevent direct mutation
|
||||
const newConfig: typeof CONFIG_DEFAULT = JSON.parse(
|
||||
JSON.stringify(localConfig)
|
||||
|
@ -302,14 +304,14 @@ export default function SettingDialog({
|
|||
const mustBeNumeric = isNumeric(CONFIG_DEFAULT[key as SettKey]);
|
||||
if (mustBeString) {
|
||||
if (!isString(value)) {
|
||||
alert(`Value for ${key} must be string`);
|
||||
await showAlert(`Value for ${key} must be string`);
|
||||
return;
|
||||
}
|
||||
} else if (mustBeNumeric) {
|
||||
const trimmedValue = value.toString().trim();
|
||||
const numVal = Number(trimmedValue);
|
||||
if (isNaN(numVal) || !isNumeric(numVal) || trimmedValue.length === 0) {
|
||||
alert(`Value for ${key} must be numeric`);
|
||||
await showAlert(`Value for ${key} must be numeric`);
|
||||
return;
|
||||
}
|
||||
// force conversion to number
|
||||
|
@ -317,7 +319,7 @@ export default function SettingDialog({
|
|||
newConfig[key] = numVal;
|
||||
} else if (mustBeBoolean) {
|
||||
if (!isBoolean(value)) {
|
||||
alert(`Value for ${key} must be boolean`);
|
||||
await showAlert(`Value for ${key} must be boolean`);
|
||||
return;
|
||||
}
|
||||
} else {
|
||||
|
|
|
@ -14,6 +14,7 @@ import {
|
|||
import { BtnWithTooltips } from '../utils/common';
|
||||
import { useAppContext } from '../utils/app.context';
|
||||
import toast from 'react-hot-toast';
|
||||
import { useModals } from './ModalProvider';
|
||||
|
||||
export default function Sidebar() {
|
||||
const params = useParams();
|
||||
|
@ -38,6 +39,7 @@ export default function Sidebar() {
|
|||
StorageUtils.offConversationChanged(handleConversationChange);
|
||||
};
|
||||
}, []);
|
||||
const { showConfirm, showPrompt } = useModals();
|
||||
|
||||
const groupedConv = useMemo(
|
||||
() => groupConversationsByDate(conversations),
|
||||
|
@ -130,7 +132,7 @@ export default function Sidebar() {
|
|||
onSelect={() => {
|
||||
navigate(`/chat/${conv.id}`);
|
||||
}}
|
||||
onDelete={() => {
|
||||
onDelete={async () => {
|
||||
if (isGenerating(conv.id)) {
|
||||
toast.error(
|
||||
'Cannot delete conversation while generating'
|
||||
|
@ -138,7 +140,7 @@ export default function Sidebar() {
|
|||
return;
|
||||
}
|
||||
if (
|
||||
window.confirm(
|
||||
await showConfirm(
|
||||
'Are you sure to delete this conversation?'
|
||||
)
|
||||
) {
|
||||
|
@ -167,14 +169,14 @@ export default function Sidebar() {
|
|||
document.body.removeChild(a);
|
||||
URL.revokeObjectURL(url);
|
||||
}}
|
||||
onRename={() => {
|
||||
onRename={async () => {
|
||||
if (isGenerating(conv.id)) {
|
||||
toast.error(
|
||||
'Cannot rename conversation while generating'
|
||||
);
|
||||
return;
|
||||
}
|
||||
const newName = window.prompt(
|
||||
const newName = await showPrompt(
|
||||
'Enter new name for the conversation',
|
||||
conv.name
|
||||
);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue