/*
 * Mirror of https://github.com/LostRuins/koboldcpp.git
 * Synced 2025-09-10 17:14:36 +00:00
 * 180 lines, no EOL, 8.8 KiB, C
 */
#ifndef RWKV_H
|
|
#define RWKV_H
|
|
|
|
#include <stddef.h>
|
|
#include <stdint.h>
|
|
#include <stdbool.h>
|
|
|
|
// RWKV_API marks every public symbol of the library.
// - Static build (RWKV_SHARED undefined): expands to nothing.
// - Shared build on MSVC-style Windows toolchains: dllexport while building the library
//   (RWKV_BUILD defined), dllimport when consuming it.
// - Shared build elsewhere (including MinGW): default ELF/Mach-O symbol visibility.
#if !defined(RWKV_SHARED)
#    define RWKV_API
#elif defined(_WIN32) && !defined(__MINGW32__)
#    if defined(RWKV_BUILD)
#        define RWKV_API __declspec(dllexport)
#    else
#        define RWKV_API __declspec(dllimport)
#    endif
#else
#    define RWKV_API __attribute__ ((visibility ("default")))
#endif
|
|
|
|
// File magic: the characters 'g' 'g' 'm' 'f' (0x67 0x67 0x6D 0x66) written as one hex constant.
#define RWKV_FILE_MAGIC 0x67676D66

// Known model file format versions.
#define RWKV_FILE_VERSION_0 100
#define RWKV_FILE_VERSION_1 101

// Oldest and newest known file format versions.
#define RWKV_FILE_VERSION_MIN RWKV_FILE_VERSION_0
#define RWKV_FILE_VERSION_MAX RWKV_FILE_VERSION_1

// The default file version is always the newest one.
#define RWKV_FILE_VERSION RWKV_FILE_VERSION_MAX
|
|
|
|
#ifdef __cplusplus
|
|
extern "C" {
|
|
#endif
|
|
|
|
// Represents an error encountered during a function call.
//
// Values may be OR-ed together, so a reported value can carry more than one code. The encoding
// is split into two fields rather than one bit per error:
//   - the high byte (bits 8..15) holds an error category:
//     RWKV_ERROR_ARGS .. RWKV_ERROR_CTX, i.e. 1 << 8 .. 6 << 8;
//   - the low byte (bits 0..7) holds a specific error:
//     RWKV_ERROR_ALLOC .. RWKV_ERROR_PARAM_MISSING, i.e. 1 .. 14.
//
// Category values are small integers shifted left, NOT single bits
// (e.g. RWKV_ERROR_MODEL == RWKV_ERROR_ARGS | RWKV_ERROR_FILE). Testing a category with
// `err & RWKV_ERROR_FILE` would therefore also match RWKV_ERROR_MODEL and RWKV_ERROR_CTX.
// Compare the masked field instead:
//   (err & 0xFF00) == RWKV_ERROR_FILE
//   (err & 0x00FF) == RWKV_ERROR_FILE_OPEN
// If more than one category were OR-ed into one value, the category field could not be
// decoded unambiguously.
enum rwkv_error_flags {
    RWKV_ERROR_NONE = 0,

    // Error categories (high byte).
    RWKV_ERROR_ARGS = 1 << 8,
    RWKV_ERROR_FILE = 2 << 8,
    RWKV_ERROR_MODEL = 3 << 8,
    RWKV_ERROR_MODEL_PARAMS = 4 << 8,
    RWKV_ERROR_GRAPH = 5 << 8,
    RWKV_ERROR_CTX = 6 << 8,

    // Specific errors (low byte).
    RWKV_ERROR_ALLOC = 1,
    RWKV_ERROR_FILE_OPEN = 2,
    RWKV_ERROR_FILE_STAT = 3,
    RWKV_ERROR_FILE_READ = 4,
    RWKV_ERROR_FILE_WRITE = 5,
    RWKV_ERROR_FILE_MAGIC = 6,
    RWKV_ERROR_FILE_VERSION = 7,
    RWKV_ERROR_DATA_TYPE = 8,
    RWKV_ERROR_UNSUPPORTED = 9,
    RWKV_ERROR_SHAPE = 10,
    RWKV_ERROR_DIMENSION = 11,
    RWKV_ERROR_KEY = 12,
    RWKV_ERROR_DATA = 13,
    RWKV_ERROR_PARAM_MISSING = 14
};
|
|
|
|
// Opaque RWKV context that can be used for inference.
// A context may be handed to a different thread between calls to rwkv_eval, and it does not
// have to be freed on the thread that created it.
// A single context is NOT safe for concurrent use: only one rwkv_eval / rwkv_eval_sequence may
// run on it at a time. To evaluate in parallel, create one context per thread with
// rwkv_clone_context.
struct rwkv_context;
|
|
|
|
// Sets whether errors are automatically printed to stderr.
|
|
// If this is set to false, you are responsible for calling rwkv_last_error manually if an operation fails.
|
|
// - ctx: the context to suppress error messages for.
|
|
// If NULL, affects model load (rwkv_init_from_file) and quantization (rwkv_quantize_model_file) errors,
|
|
// as well as the default for new context.
|
|
// - print_errors: whether error messages should be automatically printed.
|
|
RWKV_API void rwkv_set_print_errors(struct rwkv_context * ctx, bool print_errors);
|
|
|
|
// Gets whether errors are automatically printed to stderr.
|
|
// - ctx: the context to retrieve the setting for, or NULL for the global setting.
|
|
RWKV_API bool rwkv_get_print_errors(struct rwkv_context * ctx);
|
|
|
|
// Retrieves and clears the error flags.
|
|
// - ctx: the context the retrieve the error for, or NULL for the global error.
|
|
RWKV_API enum rwkv_error_flags rwkv_get_last_error(struct rwkv_context * ctx);
|
|
|
|
// Loads the model from a file and prepares it for inference.
|
|
// Returns NULL on any error.
|
|
// - model_file_path: path to model file in ggml format.
|
|
// - n_threads: count of threads to use, must be positive.
|
|
RWKV_API struct rwkv_context * rwkv_init_from_file(const char * model_file_path, const uint32_t n_threads);
|
|
|
|
// Creates a new context from an existing one.
|
|
// This can allow you to run multiple rwkv_eval's in parallel, without having to load a single model multiple times.
|
|
// Each rwkv_context can have one eval running at a time.
|
|
// Every rwkv_context must be freed using rwkv_free.
|
|
// - ctx: context to be cloned.
|
|
// - n_threads: count of threads to use, must be positive.
|
|
RWKV_API struct rwkv_context * rwkv_clone_context(struct rwkv_context * ctx, const uint32_t n_threads);
|
|
|
|
// Offloads specified count of model layers onto the GPU. Offloaded layers are evaluated using cuBLAS.
|
|
// Returns true if at least one layer was offloaded.
|
|
// If rwkv.cpp was compiled without cuBLAS support, this function is a no-op and always returns false.
|
|
RWKV_API bool rwkv_gpu_offload_layers(struct rwkv_context * ctx, const uint32_t n_layers);
|
|
|
|
// Evaluates the model for a single token.
|
|
// Not thread-safe. For parallel inference, call rwkv_clone_context to create one rwkv_context for each thread.
|
|
// Returns false on any error.
|
|
// You can pass NULL to logits_out whenever logits are not needed. This can improve speed by ~10ms per iteration
|
|
// that you do not calculate logits.
|
|
// - token: next token index, in range 0 <= token < n_vocab.
|
|
// - state_in: FP32 buffer of size rwkv_get_state_len(); or NULL, if this is a first pass.
|
|
// - state_out: FP32 buffer of size rwkv_get_state_len(). This buffer will be written to if non-NULL.
|
|
// - logits_out: FP32 buffer of size rwkv_get_logits_len(). This buffer will be written to if non-NULL.
|
|
RWKV_API bool rwkv_eval(struct rwkv_context *, const int n_threads, const uint32_t token, const float * state_in, float * state_out, float * logits_out);
|
|
|
|
// Evaluates the model for a sequence of tokens.
|
|
// Uses a faster algorithm than rwkv_eval if you do not need the state and logits for every token. Best used with batch sizes of 64 or so.
|
|
// Has to build a computation graph on the first call for a given sequence, but will use this cached graph for subsequent calls of the same sequence length.
|
|
// Not thread-safe. For parallel inference, call rwkv_clone_context to create one rwkv_context for each thread.
|
|
// Returns false on any error.
|
|
// You can pass NULL to logits_out whenever logits are not needed. This can improve speed by ~10ms per iteration
|
|
// that you do not calculate logits.
|
|
// - tokens: pointer to an array of tokens. If NULL, the graph will be built and cached, but not executed: this can be useful for initialization.
|
|
// - sequence_len: number of tokens to read from the array.
|
|
// - state_in: FP32 buffer of size rwkv_get_state_len(), or NULL if this is a first pass.
|
|
// - state_out: FP32 buffer of size rwkv_get_state_len(). This buffer will be written to if non-NULL.
|
|
// - logits_out: FP32 buffer of size rwkv_get_logits_len(). This buffer will be written to if non-NULL.
|
|
RWKV_API bool rwkv_eval_sequence(struct rwkv_context * ctx, const int n_threads, const uint32_t * tokens, size_t sequence_len, const float * state_in, float * state_out, float * logits_out);
|
|
|
|
// Returns the number of tokens in the given model's vocabulary.
|
|
// Useful for telling 20B_tokenizer models (n_vocab = 50277) apart from World models (n_vocab = 65536).
|
|
RWKV_API size_t rwkv_get_n_vocab(const struct rwkv_context * ctx);
|
|
|
|
// Returns the number of elements in the given model's embedding.
|
|
// Useful for reading individual fields of a model's hidden state.
|
|
RWKV_API size_t rwkv_get_n_embed(const struct rwkv_context * ctx);
|
|
|
|
// Returns the number of layers in the given model.
|
|
// Useful for always offloading the entire model to GPU.
|
|
RWKV_API size_t rwkv_get_n_layer(const struct rwkv_context * ctx);
|
|
|
|
// Returns the number of float elements in a complete state for the given model.
|
|
// This is the number of elements you'll need to allocate for a call to rwkv_eval, rwkv_eval_sequence, or rwkv_init_state.
|
|
RWKV_API size_t rwkv_get_state_len(const struct rwkv_context * ctx);
|
|
|
|
// Returns the number of float elements in the logits output of a given model.
|
|
// This is currently always identical to n_vocab.
|
|
RWKV_API size_t rwkv_get_logits_len(const struct rwkv_context * ctx);
|
|
|
|
// Initializes the given state so that passing it to rwkv_eval or rwkv_eval_sequence would be identical to passing NULL.
|
|
// Useful in cases where tracking the first call to these functions may be annoying or expensive.
|
|
// State must be initialized for behavior to be defined, passing a zeroed state to rwkv.cpp functions will result in NaNs.
|
|
// - state: FP32 buffer of size rwkv_get_state_len() to initialize
|
|
RWKV_API void rwkv_init_state(const struct rwkv_context * ctx, float * state);
|
|
|
|
// Frees all allocated memory and the context.
|
|
// Does not need to be called on the same thread that created the rwkv_context.
|
|
RWKV_API void rwkv_free(struct rwkv_context * ctx);
|
|
|
|
// Quantizes FP32 or FP16 model to one of quantized formats.
|
|
// Returns false on any error. Error messages would be printed to stderr.
|
|
// - model_file_path_in: path to model file in ggml format, must be either FP32 or FP16.
|
|
// - model_file_path_out: quantized model will be written here.
|
|
// - format_name: must be one of available format names below.
|
|
// Available format names:
|
|
// - Q4_0
|
|
// - Q4_1
|
|
// - Q5_0
|
|
// - Q5_1
|
|
// - Q8_0
|
|
RWKV_API bool rwkv_quantize_model_file(const char * model_file_path_in, const char * model_file_path_out, const char * format_name);
|
|
|
|
// Returns system information string.
|
|
RWKV_API const char * rwkv_get_system_info_string(void);
|
|
|
|
#ifdef __cplusplus
|
|
}
|
|
#endif
|
|
|
|
#endif |