mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2025-09-17 04:19:40 +00:00
fixed rwkv, standardized new ctx usage
This commit is contained in:
parent
2827920044
commit
523fc3be52
8 changed files with 27 additions and 8 deletions
|
@ -13,6 +13,8 @@
|
|||
#include "ggml-opencl.h"
|
||||
#endif
|
||||
|
||||
#include "utils.h"
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <cstring>
|
||||
|
@ -729,6 +731,7 @@ struct rwkv_context {
|
|||
float * logits_out = 0; //stores address of output logit buffer
|
||||
|
||||
size_t gpu_layers;
|
||||
std::vector<uint8_t> work_buffer;
|
||||
};
|
||||
|
||||
// https://stackoverflow.com/a/6458689
|
||||
|
@ -1627,7 +1630,7 @@ bool rwkv_eval(struct rwkv_context * ctx, const int n_threads, const uint32_t to
|
|||
ctx->serial_graph.cgraph->n_leafs = ctx->serial_graph.post_logits_leafs;
|
||||
}
|
||||
|
||||
ggml_graph_compute_with_ctx(ctx->serial_graph.ctx.ctx, ctx->serial_graph.cgraph.get(),n_threads);
|
||||
kcpp_graph_compute_helper(ctx->serial_graph.cgraph.get(),n_threads);
|
||||
rwkv_get_outputs(ctx, state_out, logits_out);
|
||||
|
||||
return true;
|
||||
|
@ -1715,7 +1718,7 @@ bool rwkv_eval_sequence(struct rwkv_context * ctx, const int n_threads, const ui
|
|||
ctx->sequence_graph.cgraph->n_leafs = ctx->sequence_graph.post_logits_leafs;
|
||||
}
|
||||
|
||||
ggml_graph_compute_with_ctx(ctx->sequence_graph.ctx.ctx, ctx->sequence_graph.cgraph.get(),n_threads);
|
||||
kcpp_graph_compute_helper(ctx->sequence_graph.cgraph.get(),n_threads);
|
||||
rwkv_get_outputs(ctx, state_out, logits_out);
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue