fixed rwkv, standardized new ctx usage

This commit is contained in:
Concedo 2023-07-10 20:05:53 +08:00
parent 2827920044
commit 523fc3be52
8 changed files with 27 additions and 8 deletions

View file

@ -13,6 +13,8 @@
#include "ggml-opencl.h"
#endif
#include "utils.h"
#include <string>
#include <vector>
#include <cstring>
@ -729,6 +731,7 @@ struct rwkv_context {
float * logits_out = 0; //stores address of output logit buffer
size_t gpu_layers;
std::vector<uint8_t> work_buffer;
};
// https://stackoverflow.com/a/6458689
@ -1627,7 +1630,7 @@ bool rwkv_eval(struct rwkv_context * ctx, const int n_threads, const uint32_t to
ctx->serial_graph.cgraph->n_leafs = ctx->serial_graph.post_logits_leafs;
}
ggml_graph_compute_with_ctx(ctx->serial_graph.ctx.ctx, ctx->serial_graph.cgraph.get(),n_threads);
kcpp_graph_compute_helper(ctx->serial_graph.cgraph.get(),n_threads);
rwkv_get_outputs(ctx, state_out, logits_out);
return true;
@ -1715,7 +1718,7 @@ bool rwkv_eval_sequence(struct rwkv_context * ctx, const int n_threads, const ui
ctx->sequence_graph.cgraph->n_leafs = ctx->sequence_graph.post_logits_leafs;
}
ggml_graph_compute_with_ctx(ctx->sequence_graph.ctx.ctx, ctx->sequence_graph.cgraph.get(),n_threads);
kcpp_graph_compute_helper(ctx->sequence_graph.cgraph.get(),n_threads);
rwkv_get_outputs(ctx, state_out, logits_out);
}