hope i didnt break anything

Concedo 2025-08-14 21:42:24 +08:00
commit 7ac0102ed3
45 changed files with 2114 additions and 1090 deletions


@ -0,0 +1,43 @@
name: Build on RISCV Linux Machine by Cloud-V
on:
workflow_dispatch:
workflow_call:
jobs:
bianbu-riscv64-native: # Bianbu 2.2
runs-on: self-hosted
steps:
- name: Install prerequisites
run: |
sudo apt-get update || true
sudo apt-get install -y libatomic1
- uses: actions/checkout@v4
- name: Setup Riscv
run: |
sudo apt-get update || true
sudo apt-get install -y --no-install-recommends \
build-essential \
gcc-14-riscv64-linux-gnu \
g++-14-riscv64-linux-gnu \
cmake
- name: Build
run: |
cmake -B build -DLLAMA_CURL=OFF \
-DCMAKE_BUILD_TYPE=Release \
-DGGML_OPENMP=OFF \
-DLLAMA_BUILD_EXAMPLES=ON \
-DLLAMA_BUILD_TOOLS=ON \
-DLLAMA_BUILD_TESTS=OFF \
-DCMAKE_SYSTEM_NAME=Linux \
-DCMAKE_SYSTEM_PROCESSOR=riscv64 \
-DCMAKE_C_COMPILER=riscv64-linux-gnu-gcc-14 \
-DCMAKE_CXX_COMPILER=riscv64-linux-gnu-g++-14 \
-DCMAKE_POSITION_INDEPENDENT_CODE=ON \
-DCMAKE_FIND_ROOT_PATH=/usr/lib/riscv64-linux-gnu \
-DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \
-DCMAKE_FIND_ROOT_PATH_MODE_LIBRARY=ONLY \
-DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH
cmake --build build --config Release -j $(nproc)


@ -0,0 +1,53 @@
name: "Copilot Setup Steps"
# Automatically run the setup steps when they are changed to allow for easy validation, and
# allow manual testing through the repository's "Actions" tab
on:
workflow_dispatch:
push:
paths:
- .github/workflows/copilot-setup-steps.yml
pull_request:
paths:
- .github/workflows/copilot-setup-steps.yml
jobs:
# The job MUST be called `copilot-setup-steps` or it will not be picked up by Copilot.
copilot-setup-steps:
runs-on: ubuntu-latest
# Set the permissions to the lowest permissions possible needed for your steps.
# Copilot will be given its own token for its operations.
permissions:
# If you want to clone the repository as part of your setup steps, for example to install dependencies, you'll need the `contents: read` permission. If you don't clone the repository in your setup steps, Copilot will do this for you automatically after the steps complete.
contents: read
# You can define any steps you want, and they will run before the agent starts.
# If you do not check out your code, Copilot will do this for you.
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: ccache
uses: hendrikmuhs/ccache-action@v1.2.16
with:
key: copilot-setup-steps
evict-old-files: 1d
- name: Dependencies
id: depends
run: |
sudo apt-get update
sudo apt-get install build-essential libcurl4-openssl-dev
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.11'
- name: Install Python dependencies
run: |
python3 -m venv .venv
source .venv/bin/activate
pip install -r requirements/requirements-all.txt -r tools/server/tests/requirements.txt
pip install flake8 pyright


@ -751,6 +751,39 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string &
// utils // utils
// //
// Helper function to parse tensor buffer override strings
static void parse_tensor_buffer_overrides(const std::string & value, std::vector<llama_model_tensor_buft_override> & overrides) {
std::map<std::string, ggml_backend_buffer_type_t> buft_list;
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
auto * dev = ggml_backend_dev_get(i);
auto * buft = ggml_backend_dev_buffer_type(dev);
if (buft) {
buft_list[ggml_backend_buft_name(buft)] = buft;
}
}
for (const auto & override : string_split<std::string>(value, ',')) {
std::string::size_type pos = override.find('=');
if (pos == std::string::npos) {
throw std::invalid_argument("invalid value");
}
std::string tensor_name = override.substr(0, pos);
std::string buffer_type = override.substr(pos + 1);
if (buft_list.find(buffer_type) == buft_list.end()) {
printf("Available buffer types:\n");
for (const auto & it : buft_list) {
printf(" %s\n", ggml_backend_buft_name(it.second));
}
throw std::invalid_argument("unknown buffer type");
}
// keep strings alive and avoid leaking memory by storing them in a static vector
static std::list<std::string> buft_overrides;
buft_overrides.push_back(tensor_name);
overrides.push_back({buft_overrides.back().c_str(), buft_list.at(buffer_type)});
}
}
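For illustration only (not part of this diff), a minimal sketch of how the parsed, null-terminated override list is typically handed to the model loader, roughly what common_model_params_to_llama does. On the command line this corresponds to e.g. -ot "\.ffn_(up|down|gate)_exps=CPU", assuming the CPU backend's buffer type is named "CPU":

    // illustrative sketch: pass the parsed overrides on to llama_model_params
    llama_model_params mparams = llama_model_default_params();
    if (!params.tensor_buft_overrides.empty()) {
        // the vector is terminated with {nullptr, nullptr} after parsing (see below)
        mparams.tensor_buft_overrides = params.tensor_buft_overrides.data();
    }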
struct handle_model_result { struct handle_model_result {
bool found_mmproj = false; bool found_mmproj = false;
common_params_model mmproj; common_params_model mmproj;
@ -995,6 +1028,10 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
params.tensor_buft_overrides.push_back({nullptr, nullptr}); params.tensor_buft_overrides.push_back({nullptr, nullptr});
} }
if (!params.speculative.tensor_buft_overrides.empty()) {
params.speculative.tensor_buft_overrides.push_back({nullptr, nullptr});
}
if (!params.chat_template.empty() && !common_chat_verify_template(params.chat_template, params.use_jinja)) { if (!params.chat_template.empty() && !common_chat_verify_template(params.chat_template, params.use_jinja)) {
throw std::runtime_error(string_format( throw std::runtime_error(string_format(
"error: the supplied chat template is not supported: %s%s\n", "error: the supplied chat template is not supported: %s%s\n",
@ -1203,6 +1240,7 @@ bool common_params_parse(int argc, char ** argv, common_params & params, llama_e
common_params_print_completion(ctx_arg); common_params_print_completion(ctx_arg);
exit(0); exit(0);
} }
params.lr.init();
} catch (const std::invalid_argument & ex) { } catch (const std::invalid_argument & ex) {
fprintf(stderr, "%s\n", ex.what()); fprintf(stderr, "%s\n", ex.what());
ctx_arg.params = params_org; ctx_arg.params = params_org;
@ -1471,6 +1509,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.swa_full = true; params.swa_full = true;
} }
).set_env("LLAMA_ARG_SWA_FULL")); ).set_env("LLAMA_ARG_SWA_FULL"));
add_opt(common_arg(
{"--swa-checkpoints"}, "N",
string_format("max number of SWA checkpoints per slot to create (default: %d)\n"
"[(more info)](https://github.com/ggml-org/llama.cpp/pull/15293)", params.n_swa_checkpoints),
[](common_params & params, int value) {
params.n_swa_checkpoints = value;
}
).set_env("LLAMA_ARG_SWA_CHECKPOINTS").set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg( add_opt(common_arg(
{"--kv-unified", "-kvu"}, {"--kv-unified", "-kvu"},
string_format("use single unified KV buffer for the KV cache of all sequences (default: %s)\n" string_format("use single unified KV buffer for the KV cache of all sequences (default: %s)\n"
@ -2351,40 +2397,15 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
add_opt(common_arg( add_opt(common_arg(
{"--override-tensor", "-ot"}, "<tensor name pattern>=<buffer type>,...", {"--override-tensor", "-ot"}, "<tensor name pattern>=<buffer type>,...",
"override tensor buffer type", [](common_params & params, const std::string & value) { "override tensor buffer type", [](common_params & params, const std::string & value) {
/* static */ std::map<std::string, ggml_backend_buffer_type_t> buft_list; parse_tensor_buffer_overrides(value, params.tensor_buft_overrides);
if (buft_list.empty()) {
// enumerate all the devices and add their buffer types to the list
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
auto * dev = ggml_backend_dev_get(i);
auto * buft = ggml_backend_dev_buffer_type(dev);
if (buft) {
buft_list[ggml_backend_buft_name(buft)] = buft;
}
}
}
for (const auto & override : string_split<std::string>(value, ',')) {
std::string::size_type pos = override.find('=');
if (pos == std::string::npos) {
throw std::invalid_argument("invalid value");
}
std::string tensor_name = override.substr(0, pos);
std::string buffer_type = override.substr(pos + 1);
if (buft_list.find(buffer_type) == buft_list.end()) {
printf("Available buffer types:\n");
for (const auto & it : buft_list) {
printf(" %s\n", ggml_backend_buft_name(it.second));
}
throw std::invalid_argument("unknown buffer type");
}
// keep strings alive and avoid leaking memory by storing them in a static vector
static std::list<std::string> buft_overrides;
buft_overrides.push_back(tensor_name);
params.tensor_buft_overrides.push_back({buft_overrides.back().c_str(), buft_list.at(buffer_type)});
}
} }
)); ));
add_opt(common_arg(
{"--override-tensor-draft", "-otd"}, "<tensor name pattern>=<buffer type>,...",
"override tensor buffer type for draft model", [](common_params & params, const std::string & value) {
parse_tensor_buffer_overrides(value, params.speculative.tensor_buft_overrides);
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg( add_opt(common_arg(
{"--cpu-moe", "-cmoe"}, {"--cpu-moe", "-cmoe"},
"keep all Mixture of Experts (MoE) weights in the CPU", "keep all Mixture of Experts (MoE) weights in the CPU",
@ -2407,6 +2428,27 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
} }
} }
).set_env("LLAMA_ARG_N_CPU_MOE")); ).set_env("LLAMA_ARG_N_CPU_MOE"));
add_opt(common_arg(
{"--cpu-moe-draft", "-cmoed"},
"keep all Mixture of Experts (MoE) weights in the CPU for the draft model",
[](common_params & params) {
params.speculative.tensor_buft_overrides.push_back({"\\.ffn_(up|down|gate)_exps", ggml_backend_cpu_buffer_type()});
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CPU_MOE_DRAFT"));
add_opt(common_arg(
{"--n-cpu-moe-draft", "-ncmoed"}, "N",
"keep the Mixture of Experts (MoE) weights of the first N layers in the CPU for the draft model",
[](common_params & params, int value) {
if (value < 0) {
throw std::invalid_argument("invalid value");
}
for (int i = 0; i < value; ++i) {
static std::list<std::string> buft_overrides_draft;
buft_overrides_draft.push_back(string_format("blk\\.%d\\.ffn_(up|down|gate)_exps", i));
params.speculative.tensor_buft_overrides.push_back({buft_overrides_draft.back().c_str(), ggml_backend_cpu_buffer_type()});
}
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_N_CPU_MOE_DRAFT"));
add_opt(common_arg( add_opt(common_arg(
{"-ngl", "--gpu-layers", "--n-gpu-layers"}, "N", {"-ngl", "--gpu-layers", "--n-gpu-layers"}, "N",
"number of layers to store in VRAM", "number of layers to store in VRAM",
@ -2657,7 +2699,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
[](common_params & params, const std::string & value) { [](common_params & params, const std::string & value) {
params.out_file = value; params.out_file = value;
} }
).set_examples({LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_CVECTOR_GENERATOR, LLAMA_EXAMPLE_EXPORT_LORA, LLAMA_EXAMPLE_TTS})); ).set_examples({LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_CVECTOR_GENERATOR, LLAMA_EXAMPLE_EXPORT_LORA, LLAMA_EXAMPLE_TTS, LLAMA_EXAMPLE_FINETUNE}));
add_opt(common_arg( add_opt(common_arg(
{"-ofreq", "--output-frequency"}, "N", {"-ofreq", "--output-frequency"}, "N",
string_format("output the imatrix every N iterations (default: %d)", params.n_out_freq), string_format("output the imatrix every N iterations (default: %d)", params.n_out_freq),
@ -3132,7 +3174,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.speculative.cpuparams.n_threads = std::thread::hardware_concurrency(); params.speculative.cpuparams.n_threads = std::thread::hardware_concurrency();
} }
} }
).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg( add_opt(common_arg(
{"-tbd", "--threads-batch-draft"}, "N", {"-tbd", "--threads-batch-draft"}, "N",
"number of threads to use during batch and prompt processing (default: same as --threads-draft)", "number of threads to use during batch and prompt processing (default: same as --threads-draft)",
@ -3142,7 +3184,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.speculative.cpuparams_batch.n_threads = std::thread::hardware_concurrency(); params.speculative.cpuparams_batch.n_threads = std::thread::hardware_concurrency();
} }
} }
).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg( add_opt(common_arg(
{"-Cd", "--cpu-mask-draft"}, "M", {"-Cd", "--cpu-mask-draft"}, "M",
"Draft model CPU affinity mask. Complements cpu-range-draft (default: same as --cpu-mask)", "Draft model CPU affinity mask. Complements cpu-range-draft (default: same as --cpu-mask)",
@ -3535,5 +3577,51 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
).set_examples({ LLAMA_EXAMPLE_DIFFUSION })); ).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
add_opt(
common_arg({ "-lr", "--learning-rate" }, "ALPHA",
string_format(
"adamw or sgd optimizer alpha (default: %.2g); note: sgd alpha recommended ~10x (no momentum)",
(double) params.lr.lr0),
[](common_params & params, const std::string & value) { params.lr.lr0 = std::stof(value); })
.set_examples({ LLAMA_EXAMPLE_FINETUNE }));
add_opt(
common_arg({ "-lr-min", "--learning-rate-min" }, "ALPHA",
string_format(
"(if >0) final learning rate after decay (if -decay-epochs is set, default=%.2g)",
(double) params.lr.lr_min),
[](common_params & params, const std::string & value) { params.lr.lr_min = std::stof(value); })
.set_examples({ LLAMA_EXAMPLE_FINETUNE }));
add_opt(
common_arg({ "-decay-epochs", "--learning-rate-decay-epochs" }, "ALPHA",
string_format(
"(if >0) decay learning rate to -lr-min after this many epochs (exponential decay, default=%.2g)",
(double) params.lr.decay_epochs),
[](common_params & params, const std::string & value) { params.lr.decay_epochs = std::stof(value); })
.set_examples({ LLAMA_EXAMPLE_FINETUNE }));
add_opt(common_arg(
{ "-wd", "--weight-decay" }, "WD",
string_format(
"adamw or sgd optimizer weight decay (0 is off; recommend very small e.g. 1e-9) (default: %.2g).",
(double) params.lr.wd),
[](common_params & params, const std::string & value) { params.lr.wd = std::stof(value); })
.set_examples({ LLAMA_EXAMPLE_FINETUNE }));
add_opt(common_arg({ "-val-split", "--val-split" }, "FRACTION",
string_format("fraction of data to use as validation set for training (default: %.2g).",
(double) params.val_split),
[](common_params & params, const std::string & value) { params.val_split = std::stof(value); })
.set_examples({ LLAMA_EXAMPLE_FINETUNE }));
add_opt(common_arg({ "-epochs", "--epochs" }, "N",
string_format("optimizer max # of epochs (default: %d)", params.lr.epochs),
[](common_params & params, int epochs) { params.lr.epochs = epochs; })
.set_examples({ LLAMA_EXAMPLE_FINETUNE }));
add_opt(common_arg({ "-opt", "--optimizer" }, "sgd|adamw", "adamw or sgd",
[](common_params & params, const std::string & name) {
params.optimizer = common_opt_get_optimizer(name.c_str());
if (params.optimizer == GGML_OPT_OPTIMIZER_TYPE_COUNT) {
throw std::invalid_argument("invalid --optimizer, valid options: adamw, sgd");
}
})
.set_examples({ LLAMA_EXAMPLE_FINETUNE }));
return ctx_arg; return ctx_arg;
} }


@ -49,6 +49,7 @@
#endif #endif
#include <locale> #include <locale>
#include <windows.h> #include <windows.h>
#include <string.h>
#include <fcntl.h> #include <fcntl.h>
#include <io.h> #include <io.h>
#else #else
@ -1573,3 +1574,56 @@ ggml_opt_dataset_t common_opt_dataset_init(struct llama_context * ctx, const std
return result; return result;
} }
ggml_opt_optimizer_params common_opt_lr_pars(void * userdata) {
ggml_opt_optimizer_params result = ggml_opt_get_default_optimizer_params(nullptr);
const lr_opt & d = *(lr_opt *) userdata;
result.adamw.alpha = result.sgd.alpha = d.get_lr(d.epoch);
result.sgd.wd = result.adamw.wd = d.wd;
return result;
}
// TODO make all command line args case-insensitive
static inline bool eq_case_insensitive(char const* a, char const* b) {
return !
#if defined(_MSC_VER)
_stricmp
#else
strcasecmp
#endif // defined(_MSC_VER)
(a, b);
}
enum ggml_opt_optimizer_type common_opt_get_optimizer(const char * n) {
if (eq_case_insensitive("adamw", n)) {
return GGML_OPT_OPTIMIZER_TYPE_ADAMW;
}
if (eq_case_insensitive("sgd", n)) {
return GGML_OPT_OPTIMIZER_TYPE_SGD;
}
return GGML_OPT_OPTIMIZER_TYPE_COUNT;
}
// TODO simplify to use just log and exp
static float const k_log_2 = std::log(2.f);
void lr_opt::init() {
if (lr_min > 0 && lr_min < lr0) {
float nhalf = std::log(lr0 / lr_min) / k_log_2;
float e = epochs;
if (decay_epochs > 0 && decay_epochs < e) {
e = decay_epochs;
} else {
decay_epochs = e;
}
scale_epoch = nhalf / e;
}
}
float lr_opt::get_lr(float epoch) const {
float r = lr_min <= 0 ? lr0 :
epoch >= decay_epochs ? lr_min :
lr0 * std::pow(0.5f, epoch * scale_epoch);
LOG_INF("epoch %.2g lr=%.2g\n", epoch, r);
return r;
}
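In closed form, the schedule implemented by init() and get_lr() above (assuming 0 < lr_min < lr0, and writing E_d for decay_epochs after it has been clamped to the epoch count) is:

    \mathrm{lr}(e) =
    \begin{cases}
    \mathrm{lr}_0 \cdot 2^{-e\,\log_2(\mathrm{lr}_0/\mathrm{lr}_{\min})/E_d}
      = \mathrm{lr}_0 \left(\tfrac{\mathrm{lr}_{\min}}{\mathrm{lr}_0}\right)^{e/E_d} & e < E_d \\
    \mathrm{lr}_{\min} & e \ge E_d
    \end{cases}

i.e. exponential decay that reaches lr_min exactly at E_d epochs; with lr_min <= 0 the rate stays constant at lr0.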


@ -2,14 +2,17 @@
#pragma once #pragma once
#include "llama-cpp.h"
#include <set> #include <set>
#include <sstream>
#include <string> #include <string>
#include <string_view> #include <string_view>
#include <vector> #include <vector>
#include <map> #include <map>
#include <sstream> #include <sstream>
#include <cmath>
#include "ggml-opt.h"
#include "llama-cpp.h"
#ifdef _WIN32 #ifdef _WIN32
#define DIRECTORY_SEPARATOR '\\' #define DIRECTORY_SEPARATOR '\\'
@ -78,6 +81,7 @@ enum llama_example {
LLAMA_EXAMPLE_PARALLEL, LLAMA_EXAMPLE_PARALLEL,
LLAMA_EXAMPLE_TTS, LLAMA_EXAMPLE_TTS,
LLAMA_EXAMPLE_DIFFUSION, LLAMA_EXAMPLE_DIFFUSION,
LLAMA_EXAMPLE_FINETUNE,
LLAMA_EXAMPLE_COUNT, LLAMA_EXAMPLE_COUNT,
}; };
@ -198,6 +202,7 @@ struct common_params_speculative {
float p_split = 0.1f; // speculative decoding split probability float p_split = 0.1f; // speculative decoding split probability
float p_min = 0.75f; // minimum speculative decoding probability (greedy) float p_min = 0.75f; // minimum speculative decoding probability (greedy)
std::vector<std::pair<std::string, std::string>> replacements; // main to speculative model replacements std::vector<std::pair<std::string, std::string>> replacements; // main to speculative model replacements
std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;
ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K
ggml_type cache_type_v = GGML_TYPE_F16; // KV cache data type for the V ggml_type cache_type_v = GGML_TYPE_F16; // KV cache data type for the V
@ -238,6 +243,25 @@ enum common_reasoning_format {
COMMON_REASONING_FORMAT_GRANITE, // Extract thinking tag contents and return as `message.reasoning_content`, including in streaming deltas. COMMON_REASONING_FORMAT_GRANITE, // Extract thinking tag contents and return as `message.reasoning_content`, including in streaming deltas.
}; };
struct lr_opt {
float lr0 = 1e-5; // learning rate at first epoch
float lr_min = -1;
float decay_epochs = -1; // if >0, the learning rate starts at lr0 and decays to lr_min after this many epochs
float scale_epoch = 0;
float wd = 0;
unsigned epochs = 2;
unsigned epoch; // set by optimizer outer (epochs) loop
// learning rate decay - constant LR per epoch only for now
float get_lr(float e) const;
float get_lr() const { return get_lr(epoch); }
// must call after arg parse, before get_lr
void init();
};
struct ggml_opt_optimizer_params common_opt_lr_pars(void * userdata);
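A minimal sketch (illustrative, not from this commit) of how these pieces fit together, based on the ggml_opt_params fields shown further down in ggml-opt.h:

    common_params params;
    params.lr.lr0    = 1e-4f;
    params.lr.epochs = 4;
    params.lr.init();                        // required after arg parsing (see the params.lr.init() call added in arg.cpp)
    // when filling ggml_opt_params:
    //   opt_params.optimizer       = params.optimizer;   // GGML_OPT_OPTIMIZER_TYPE_ADAMW or ..._SGD
    //   opt_params.get_opt_pars    = common_opt_lr_pars; // declared just above
    //   opt_params.get_opt_pars_ud = &params.lr;
    // the training loop sets params.lr.epoch before each epoch so get_lr() decays over time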
struct common_params { struct common_params {
int32_t n_predict = -1; // new tokens to predict int32_t n_predict = -1; // new tokens to predict
int32_t n_ctx = 4096; // context size int32_t n_ctx = 4096; // context size
@ -372,6 +396,11 @@ struct common_params {
bool no_mmproj = false; // explicitly disable multimodal model bool no_mmproj = false; // explicitly disable multimodal model
std::vector<std::string> image; // path to image file(s) std::vector<std::string> image; // path to image file(s)
// finetune
struct lr_opt lr;
enum ggml_opt_optimizer_type optimizer = GGML_OPT_OPTIMIZER_TYPE_ADAMW;
float val_split = 0.05f; // fraction of the data used for the validation set
// embedding // embedding
bool embedding = false; // get only sentence embedding bool embedding = false; // get only sentence embedding
int32_t embd_normalize = 2; // normalisation for embeddings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm) int32_t embd_normalize = 2; // normalisation for embeddings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)
@ -385,6 +414,7 @@ struct common_params {
int32_t timeout_write = timeout_read; // http write timeout in seconds int32_t timeout_write = timeout_read; // http write timeout in seconds
int32_t n_threads_http = -1; // number of threads to process HTTP requests (TODO: support threadpool) int32_t n_threads_http = -1; // number of threads to process HTTP requests (TODO: support threadpool)
int32_t n_cache_reuse = 0; // min chunk size to reuse from the cache via KV shifting int32_t n_cache_reuse = 0; // min chunk size to reuse from the cache via KV shifting
int32_t n_swa_checkpoints = 3; // max number of SWA checkpoints per slot
std::string hostname = "127.0.0.1"; std::string hostname = "127.0.0.1";
std::string public_path = ""; // NOLINT std::string public_path = ""; // NOLINT
@ -699,3 +729,6 @@ const char * const LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count";
// //
ggml_opt_dataset_t common_opt_dataset_init(struct llama_context * ctx, const std::vector<llama_token> & tokens, int64_t stride); ggml_opt_dataset_t common_opt_dataset_init(struct llama_context * ctx, const std::vector<llama_token> & tokens, int64_t stride);
// "adamw" or "sgd" (case insensitive)
enum ggml_opt_optimizer_type common_opt_get_optimizer(const char *);


@ -74,16 +74,26 @@ extern "C" {
GGML_OPT_BUILD_TYPE_OPT = 30, GGML_OPT_BUILD_TYPE_OPT = 30,
}; };
enum ggml_opt_optimizer_type {
GGML_OPT_OPTIMIZER_TYPE_ADAMW,
GGML_OPT_OPTIMIZER_TYPE_SGD,
GGML_OPT_OPTIMIZER_TYPE_COUNT
};
// parameters that control which optimizer is used and how said optimizer tries to find the minimal loss // parameters that control which optimizer is used and how said optimizer tries to find the minimal loss
struct ggml_opt_optimizer_params { struct ggml_opt_optimizer_params {
// AdamW optimizer parameters
struct { struct {
float alpha; // learning rate float alpha; // learning rate
float beta1; float beta1; // first AdamW momentum
float beta2; float beta2; // second AdamW momentum
float eps; // epsilon for numerical stability float eps; // epsilon for numerical stability
float wd; // weight decay for AdamW, use 0.0f to disable float wd; // weight decay - 0.0f to disable
} adamw; } adamw;
struct {
float alpha; // learning rate
float wd; // weight decay
} sgd;
}; };
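For reference, a minimal callback sketch (illustrative, not from this commit) that fills both parameter blocks; common_opt_lr_pars in common/common.cpp does the same, driven by the new -lr/-wd options:

    static struct ggml_opt_optimizer_params constant_pars(void * userdata) {
        struct ggml_opt_optimizer_params p = ggml_opt_get_default_optimizer_params(NULL);
        p.adamw.alpha = 1e-5f;  // AdamW learning rate
        p.sgd.alpha   = 1e-4f;  // SGD typically needs a larger alpha since it has no momentum
        p.adamw.wd    = 0.0f;
        p.sgd.wd      = 0.0f;   // weight decay disabled
        (void) userdata;
        return p;
    }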
// callback to calculate optimizer parameters prior to a backward pass // callback to calculate optimizer parameters prior to a backward pass
@ -114,6 +124,9 @@ extern "C" {
ggml_opt_get_optimizer_params get_opt_pars; // callback for calculating optimizer parameters ggml_opt_get_optimizer_params get_opt_pars; // callback for calculating optimizer parameters
void * get_opt_pars_ud; // userdata for calculating optimizer parameters void * get_opt_pars_ud; // userdata for calculating optimizer parameters
// only GGML_OPT_OPTIMIZER_TYPE_ADAMW needs m, v momenta per parameter tensor
enum ggml_opt_optimizer_type optimizer;
}; };
// get parameters for an optimization context with defaults set where possible // get parameters for an optimization context with defaults set where possible
@ -142,6 +155,10 @@ extern "C" {
// get the gradient accumulator for a node from the forward graph // get the gradient accumulator for a node from the forward graph
GGML_API struct ggml_tensor * ggml_opt_grad_acc(ggml_opt_context_t opt_ctx, struct ggml_tensor * node); GGML_API struct ggml_tensor * ggml_opt_grad_acc(ggml_opt_context_t opt_ctx, struct ggml_tensor * node);
GGML_API enum ggml_opt_optimizer_type ggml_opt_context_optimizer_type(ggml_opt_context_t); //TODO consistent naming scheme
GGML_API const char * ggml_opt_optimizer_name(enum ggml_opt_optimizer_type);
// ====== Optimization Result ====== // ====== Optimization Result ======
GGML_API ggml_opt_result_t ggml_opt_result_init(void); GGML_API ggml_opt_result_t ggml_opt_result_init(void);
@ -226,12 +243,14 @@ extern "C" {
struct ggml_tensor * outputs, // output tensor, must have shape [ne_label, ndata_batch] if labels are used struct ggml_tensor * outputs, // output tensor, must have shape [ne_label, ndata_batch] if labels are used
ggml_opt_dataset_t dataset, // dataset with data and optionally also labels ggml_opt_dataset_t dataset, // dataset with data and optionally also labels
enum ggml_opt_loss_type loss_type, // loss to minimize enum ggml_opt_loss_type loss_type, // loss to minimize
enum ggml_opt_optimizer_type optimizer, // sgd or adamw
ggml_opt_get_optimizer_params get_opt_pars, // callback to get optimizer params, userdata is pointer to epoch (of type int64_t) ggml_opt_get_optimizer_params get_opt_pars, // callback to get optimizer params, userdata is pointer to epoch (of type int64_t)
int64_t nepoch, // how many times the dataset should be iterated over int64_t nepoch, // how many times the dataset should be iterated over
int64_t nbatch_logical, // datapoints optimizer step, must be a multiple of ndata_batch in inputs/outputs int64_t nbatch_logical, // datapoints optimizer step, must be a multiple of ndata_batch in inputs/outputs
float val_split, // fraction of the dataset to use for validation, must be in [0.0f, 1.0f) float val_split, // fraction of the dataset to use for validation, must be in [0.0f, 1.0f)
bool silent); // whether or not info prints to stderr should be suppressed bool silent); // whether or not info prints to stderr should be suppressed
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif


@ -241,6 +241,8 @@
#define GGML_ROPE_TYPE_MROPE 8 #define GGML_ROPE_TYPE_MROPE 8
#define GGML_ROPE_TYPE_VISION 24 #define GGML_ROPE_TYPE_VISION 24
#define GGML_MROPE_SECTIONS 4
#define GGML_UNUSED(x) (void)(x) #define GGML_UNUSED(x) (void)(x)
#define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1)) #define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))
@ -546,6 +548,7 @@ extern "C" {
GGML_OP_CROSS_ENTROPY_LOSS, GGML_OP_CROSS_ENTROPY_LOSS,
GGML_OP_CROSS_ENTROPY_LOSS_BACK, GGML_OP_CROSS_ENTROPY_LOSS_BACK,
GGML_OP_OPT_STEP_ADAMW, GGML_OP_OPT_STEP_ADAMW,
GGML_OP_OPT_STEP_SGD,
GGML_OP_GLU, GGML_OP_GLU,
@ -1673,7 +1676,7 @@ extern "C" {
struct ggml_tensor * b, struct ggml_tensor * b,
struct ggml_tensor * c, struct ggml_tensor * c,
int n_dims, int n_dims,
int sections[4], int sections[GGML_MROPE_SECTIONS],
int mode, int mode,
int n_ctx_orig, int n_ctx_orig,
float freq_base, float freq_base,
@ -1699,6 +1702,22 @@ extern "C" {
float beta_fast, float beta_fast,
float beta_slow); float beta_slow);
GGML_API struct ggml_tensor * ggml_rope_multi_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
struct ggml_tensor * c,
int n_dims,
int sections[GGML_MROPE_SECTIONS],
int mode,
int n_ctx_orig,
float freq_base,
float freq_scale,
float ext_factor,
float attn_factor,
float beta_fast,
float beta_slow);
GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_rope_custom( GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_rope_custom(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a, struct ggml_tensor * a,
@ -2306,7 +2325,14 @@ extern "C" {
struct ggml_tensor * grad, struct ggml_tensor * grad,
struct ggml_tensor * m, struct ggml_tensor * m,
struct ggml_tensor * v, struct ggml_tensor * v,
struct ggml_tensor * adamw_params); // parameters such a the learning rate struct ggml_tensor * adamw_params); // parameters such as the learning rate
// stochastic gradient descent step (with weight decay)
GGML_API struct ggml_tensor * ggml_opt_step_sgd(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * grad,
struct ggml_tensor * sgd_params); // alpha, weight decay
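In equation form, the new op performs a single SGD update with decoupled weight decay (matching ggml_compute_forward_opt_step_sgd_f32 further down, where keep = 1 - alpha*wd):

    w \leftarrow (1 - \alpha\,\mathrm{wd})\,w - \alpha\,g

The same keep factor is now also hoisted out of the AdamW inner loop.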
// //
// automatic differentiation // automatic differentiation


@ -40,18 +40,22 @@
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
#elif defined(__aarch64__) || defined(__arm__) || defined(_M_ARM) || defined(_M_ARM64) #elif defined(__aarch64__) || defined(__arm__) || defined(_M_ARM) || defined(_M_ARM64)
// repack.cpp // repack.cpp
#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
#elif defined(__x86_64__) || defined(__i386__) || defined(_M_IX86) || defined(_M_X64) #elif defined(__x86_64__) || defined(__i386__) || defined(_M_IX86) || defined(_M_X64)
// repack.cpp // repack.cpp
@ -80,12 +84,14 @@
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
#elif defined(__loongarch64) #elif defined(__loongarch64)
// quants.c // quants.c
#define quantize_row_q8_K_generic quantize_row_q8_K #define quantize_row_q8_K_generic quantize_row_q8_K
@ -103,12 +109,14 @@
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
#elif defined(__riscv) #elif defined(__riscv)
// quants.c // quants.c
#define quantize_row_q8_K_generic quantize_row_q8_K #define quantize_row_q8_K_generic quantize_row_q8_K
@ -133,11 +141,13 @@
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
#elif defined(__s390x__) #elif defined(__s390x__)
// quants.c // quants.c
#define quantize_row_q8_K_generic quantize_row_q8_K #define quantize_row_q8_K_generic quantize_row_q8_K
@ -164,12 +174,14 @@
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
#elif defined(__wasm__) #elif defined(__wasm__)
// quants.c // quants.c
#define ggml_vec_dot_q4_1_q8_1_generic ggml_vec_dot_q4_1_q8_1 #define ggml_vec_dot_q4_1_q8_1_generic ggml_vec_dot_q4_1_q8_1
@ -195,10 +207,12 @@
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
#endif #endif

File diff suppressed because it is too large


@ -2036,6 +2036,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
ggml_compute_forward_opt_step_adamw(params, tensor); ggml_compute_forward_opt_step_adamw(params, tensor);
} }
break; break;
case GGML_OP_OPT_STEP_SGD:
{
ggml_compute_forward_opt_step_sgd(params, tensor);
}
break;
case GGML_OP_NONE: case GGML_OP_NONE:
{ {
// nop // nop
@ -2339,6 +2344,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_OP_CROSS_ENTROPY_LOSS: case GGML_OP_CROSS_ENTROPY_LOSS:
case GGML_OP_CROSS_ENTROPY_LOSS_BACK: case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
case GGML_OP_OPT_STEP_ADAMW: case GGML_OP_OPT_STEP_ADAMW:
case GGML_OP_OPT_STEP_SGD:
{ {
n_tasks = n_threads; n_tasks = n_threads;
} break; } break;


@ -10330,6 +10330,7 @@ static void ggml_compute_forward_opt_step_adamw_f32(
const int ir1 = MIN(ir0 + dr, nr); const int ir1 = MIN(ir0 + dr, nr);
const float * adamw_params_ptr = ggml_get_data_f32(adamw_params); const float * adamw_params_ptr = ggml_get_data_f32(adamw_params);
const float alpha = adamw_params_ptr[0]; const float alpha = adamw_params_ptr[0];
const float beta1 = adamw_params_ptr[1]; const float beta1 = adamw_params_ptr[1];
const float beta2 = adamw_params_ptr[2]; const float beta2 = adamw_params_ptr[2];
@ -10337,7 +10338,7 @@ static void ggml_compute_forward_opt_step_adamw_f32(
const float wd = adamw_params_ptr[4]; const float wd = adamw_params_ptr[4];
const float beta1h = adamw_params_ptr[5]; const float beta1h = adamw_params_ptr[5];
const float beta2h = adamw_params_ptr[6]; const float beta2h = adamw_params_ptr[6];
const float keep = 1.f - alpha * wd;
for (int ir = ir0; ir < ir1; ++ir) { for (int ir = ir0; ir < ir1; ++ir) {
const int64_t i03 = ir/(ne02*ne01); const int64_t i03 = ir/(ne02*ne01);
const int64_t i02 = (ir - i03*ne02*ne01)/ne01; const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
@ -10360,7 +10361,7 @@ static void ggml_compute_forward_opt_step_adamw_f32(
// The weight decay is applied independently of the Adam momenta m and v. // The weight decay is applied independently of the Adam momenta m and v.
// This is NOT equivalent to l2 regularization that adds w[i00]*w[i00] to the loss. // This is NOT equivalent to l2 regularization that adds w[i00]*w[i00] to the loss.
// See: https://arxiv.org/pdf/1711.05101v3.pdf // See: https://arxiv.org/pdf/1711.05101v3.pdf
w[i00] = w[i00]*(1.0f - alpha*wd) - alpha*mh/vh; w[i00] = w[i00] * keep - alpha * mh / vh;
} }
} }
} }
@ -10382,3 +10383,63 @@ void ggml_compute_forward_opt_step_adamw(
} }
} }
} }
static void ggml_compute_forward_opt_step_sgd_f32(const ggml_compute_params * params, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src0_grad = dst->src[1];
const ggml_tensor * sgd_params = dst->src[2];
GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
GGML_ASSERT(ggml_nelements(sgd_params) == 2);
const int ith = params->ith;
const int nth = params->nth;
const int nr = ggml_nrows(src0);
GGML_TENSOR_UNARY_OP_LOCALS
GGML_ASSERT(nb00 == sizeof(float));
// rows per thread
const int dr = (nr + nth - 1) / nth;
// row range for this thread
const int ir0 = dr * ith;
const int ir1 = MIN(ir0 + dr, nr);
// using adamw param subset we care about - alpha, wd - could have a separate struct
const float * sgd_params_ptr = ggml_get_data_f32(sgd_params);
const float alpha = sgd_params_ptr[0];
const float keep = 1.f - alpha * sgd_params_ptr[1];
for (int ir = ir0; ir < ir1; ++ir) {
const int64_t i03 = ir / (ne02 * ne01);
const int64_t i02 = (ir - i03 * ne02 * ne01) / ne01;
const int64_t i01 = (ir - i03 * ne02 * ne01 - i02 * ne01);
const size_t offset = i03 * nb03 + i02 * nb02 + i01 * nb01;
float * w = (float *) ((char *) src0->data + offset); // weight
const float * g = (const float *) ((const char *) src0_grad->data + offset); // grad
for (int i00 = 0; i00 < ne00; ++i00) {
w[i00] = w[i00] * keep - alpha * g[i00];
}
}
}
void ggml_compute_forward_opt_step_sgd(const ggml_compute_params * params, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
switch (src0->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_opt_step_sgd_f32(params, dst);
}
break;
default:
{
GGML_ABORT("fatal error - sgd is F32 only");
}
}
}


@ -107,7 +107,7 @@ void ggml_compute_forward_cross_entropy_loss(const struct ggml_compute_params *
void ggml_compute_forward_cross_entropy_loss_back(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_cross_entropy_loss_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_opt_step_adamw(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_opt_step_adamw(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_mul_mat(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_mul_mat(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_opt_step_sgd(const struct ggml_compute_params * params, struct ggml_tensor * dst);
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif


@ -206,6 +206,7 @@ void ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs,
const int ncols_interleaved = 4; const int ncols_interleaved = 4;
const int blocklen = 4; const int blocklen = 4;
assert(nr == 1);
assert(n % qk == 0); assert(n % qk == 0);
assert(nc % ncols_interleaved == 0); assert(nc % ncols_interleaved == 0);
@ -307,7 +308,6 @@ void ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs,
UNUSED(ncols_interleaved); UNUSED(ncols_interleaved);
UNUSED(blocklen); UNUSED(blocklen);
{
float sumf[8]; float sumf[8];
int sumi; int sumi;
@ -332,7 +332,6 @@ void ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs,
for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
} }
} }
}
void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
const int qk = QK_K; const int qk = QK_K;
@ -494,20 +493,13 @@ void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs
const int ncols_interleaved = 4; const int ncols_interleaved = 4;
const int blocklen = 4; const int blocklen = 4;
assert(nr == 1);
assert(n % qk == 0); assert(n % qk == 0);
assert(nc % ncols_interleaved == 0); assert(nc % ncols_interleaved == 0);
UNUSED(s);
UNUSED(bs); UNUSED(bs);
UNUSED(vx);
UNUSED(vy);
UNUSED(nr); UNUSED(nr);
UNUSED(nc);
UNUSED(nb);
UNUSED(ncols_interleaved);
UNUSED(blocklen);
{
float sumf[4]; float sumf[4];
int sumi; int sumi;
@ -532,6 +524,43 @@ void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs
for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
} }
} }
void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
const int qk = QK8_0;
const int nb = n / qk;
const int ncols_interleaved = 8;
const int blocklen = 8;
assert(nr == 1);
assert(n % qk == 0);
assert(nc % ncols_interleaved == 0);
UNUSED(bs);
UNUSED(nr);
float sumf[8];
int sumi;
const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
for (int x = 0; x < nc / ncols_interleaved; x++) {
const block_iq4_nlx8 * b_ptr = (const block_iq4_nlx8 *) vx + (x * nb);
for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
for (int l = 0; l < nb; l++) {
for (int k = 0; k < (qk / (2 * blocklen)); k++) {
for (int j = 0; j < ncols_interleaved; j++) {
sumi = 0;
for (int i = 0; i < blocklen; ++i) {
const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];
const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];
sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2]));
}
sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d);
}
}
}
for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
}
} }
void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
@ -934,6 +963,50 @@ void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs
} }
} }
void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
const int qk = QK8_0;
const int nb = n / qk;
const int ncols_interleaved = 8;
const int blocklen = 8;
assert(n % qk == 0);
assert(nr % 4 == 0);
assert(nc % ncols_interleaved == 0);
float sumf[4][8];
int sumi;
for (int y = 0; y < nr / 4; y++) {
const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
for (int x = 0; x < nc / ncols_interleaved; x++) {
const block_iq4_nlx8 * b_ptr = (const block_iq4_nlx8 *) vx + (x * nb);
for (int m = 0; m < 4; m++) {
for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
}
for (int l = 0; l < nb; l++) {
for (int k = 0; k < (qk / (2 * blocklen)); k++) {
for (int m = 0; m < 4; m++) {
for (int j = 0; j < ncols_interleaved; j++) {
sumi = 0;
for (int i = 0; i < blocklen; ++i) {
const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];
const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];
sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
(v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4]));
}
sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]);
}
}
}
}
for (int m = 0; m < 4; m++) {
for (int j = 0; j < ncols_interleaved; j++)
s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
}
}
}
}
} // extern "C" } // extern "C"
static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int blck_size_interleave) { static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int blck_size_interleave) {
@ -1302,15 +1375,16 @@ static int repack_iq4_nl_to_iq4_nl_4_bl(struct ggml_tensor * t, int interleave_b
return 0; return 0;
} }
GGML_ASSERT(t->type == GGML_TYPE_IQ4_NL); GGML_ASSERT(t->type == GGML_TYPE_IQ4_NL);
//GGML_ASSERT(interleave_block == 4 || interleave_block == 8);
GGML_ASSERT(interleave_block == 4); GGML_ASSERT(interleave_block == 4);
block_iq4_nlx4 * dst = (block_iq4_nlx4 *)t->data;
const block_iq4_nl * src = (const block_iq4_nl *)data; const block_iq4_nl * src = (const block_iq4_nl *)data;
block_iq4_nlx4 * dst = ( block_iq4_nlx4 *)t->data;
block_iq4_nl dst_tmp[4]; block_iq4_nl dst_tmp[4];
int nrow = ggml_nrows(t); int nrow = ggml_nrows(t);
int nrows_interleaved = 4; int nrows_interleaved = 4;
int nblocks = t->ne[0] / QK4_0; int nblocks = t->ne[0] / QK4_NL;
GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_iq4_nl)); GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_iq4_nl));
@ -1332,6 +1406,63 @@ static int repack_iq4_nl_to_iq4_nl_4_bl(struct ggml_tensor * t, int interleave_b
GGML_UNUSED(data_size); GGML_UNUSED(data_size);
} }
static block_iq4_nlx8 make_block_iq4_nlx8(block_iq4_nl * in, unsigned int blck_size_interleave) {
block_iq4_nlx8 out;
for (int i = 0; i < 8; i++) {
out.d[i] = in[i].d;
}
const int end = QK4_NL * 4 / blck_size_interleave;
if (blck_size_interleave == 8) {
for (int i = 0; i < end; ++i) {
int src_id = i % 8;
int src_offset = (i / 8) * blck_size_interleave;
int dst_offset = i * blck_size_interleave;
memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint64_t));
}
} else {
GGML_ASSERT(false);
}
return out;
}
static int repack_iq4_nl_to_iq4_nl_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
GGML_ASSERT(t->type == GGML_TYPE_IQ4_NL);
GGML_ASSERT(interleave_block == 8);
const block_iq4_nl * src = (const block_iq4_nl *)data;
block_iq4_nlx8 * dst = ( block_iq4_nlx8 *)t->data;
block_iq4_nl dst_tmp[8];
int nrow = ggml_nrows(t);
int nrows_interleaved = 8;
int nblocks = t->ne[0] / QK4_NL;
GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_iq4_nl));
if (t->ne[1] % nrows_interleaved != 0) {
return -1;
}
for (int b = 0; b < nrow; b += nrows_interleaved) {
for (int64_t x = 0; x < nblocks; x++) {
for (int i = 0; i < nrows_interleaved; i++) {
dst_tmp[i] = src[x + i * nblocks];
}
*dst++ = make_block_iq4_nlx8(dst_tmp, interleave_block);
}
src += nrows_interleaved * nblocks;
}
return 0;
GGML_UNUSED(data_size);
}
namespace ggml::cpu::repack { namespace ggml::cpu::repack {
// repack // repack
template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS>
@ -1367,6 +1498,10 @@ template <> int repack<block_iq4_nl, 4, 4>(struct ggml_tensor * t, const void *
// return repack_iq4_nl_to_iq4_nl_4_bl(t, 8, data, data_size); // return repack_iq4_nl_to_iq4_nl_4_bl(t, 8, data, data_size);
//} //}
template <> int repack<block_iq4_nl, 8, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
return repack_iq4_nl_to_iq4_nl_8_bl(t, 8, data, data_size);
}
// gemv // gemv
template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PARAM_TYPE> template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PARAM_TYPE>
void gemv(int, float *, size_t, const void *, const void *, int, int); void gemv(int, float *, size_t, const void *, const void *, int, int);
@ -1395,6 +1530,10 @@ template <> void gemv<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size
ggml_gemv_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc); ggml_gemv_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
} }
template <> void gemv<block_iq4_nl, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
ggml_gemv_iq4_nl_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
}
// gemm // gemm
template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PARAM_TYPE> template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PARAM_TYPE>
void gemm(int, float *, size_t, const void *, const void *, int, int); void gemm(int, float *, size_t, const void *, const void *, int, int);
@ -1423,6 +1562,10 @@ template <> void gemm<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size
ggml_gemm_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc); ggml_gemm_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
} }
template <> void gemm<block_iq4_nl, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
ggml_gemm_iq4_nl_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
}
class tensor_traits_base : public ggml::cpu::tensor_traits { class tensor_traits_base : public ggml::cpu::tensor_traits {
public: public:
virtual int repack(struct ggml_tensor * t, const void * data, size_t data_size) = 0; virtual int repack(struct ggml_tensor * t, const void * data, size_t data_size) = 0;
@ -1706,6 +1849,7 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
// instance for IQ4 // instance for IQ4
static const ggml::cpu::repack::tensor_traits<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0> iq4_nl_4x4_q8_0; static const ggml::cpu::repack::tensor_traits<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0> iq4_nl_4x4_q8_0;
static const ggml::cpu::repack::tensor_traits<block_iq4_nl, 8, 8, GGML_TYPE_Q8_0> iq4_nl_8x8_q8_0;
bool permit_repack = true; bool permit_repack = true;
#if defined(GGML_USE_CLBLAST) #if defined(GGML_USE_CLBLAST)
@ -1741,6 +1885,11 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
} }
} }
} else if (cur->type == GGML_TYPE_IQ4_NL) { } else if (cur->type == GGML_TYPE_IQ4_NL) {
if (ggml_cpu_has_avx2()) {
if (cur->ne[1] % 8 == 0) {
return &iq4_nl_8x8_q8_0;
}
}
if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) { if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
if (cur->ne[1] % 4 == 0) { if (cur->ne[1] % 4 == 0) {
return &iq4_nl_4x4_q8_0; return &iq4_nl_4x4_q8_0;


@ -67,6 +67,13 @@ struct block_iq4_nlx4 {
static_assert(sizeof(block_iq4_nlx4) == 4 * sizeof(ggml_half) + QK4_NL * 2, "wrong iq4_nlx4 block size/padding"); static_assert(sizeof(block_iq4_nlx4) == 4 * sizeof(ggml_half) + QK4_NL * 2, "wrong iq4_nlx4 block size/padding");
struct block_iq4_nlx8 {
ggml_half d[8]; // deltas for 8 iq4_nl blocks
uint8_t qs[QK4_NL * 4]; // nibbles / quants for 8 iq4_nl blocks
};
static_assert(sizeof(block_iq4_nlx8) == 8 * sizeof(ggml_half) + QK4_NL * 4, "wrong iq4_nlx8 block size/padding");
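Size check for the new x8 block (assuming QK4_NL == 32 and a 2-byte ggml_half):

    //   d : 8 scales * 2 bytes              =  16 bytes
    //   qs: 8 blocks * 32 values * 4 bits   = 128 bytes  (= QK4_NL * 4)
    //   total                               = 144 bytes, which is what the static_assert verifies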
#if defined(__cplusplus) #if defined(__cplusplus)
extern "C" { extern "C" {
#endif #endif
@ -80,12 +87,14 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
// Native implementations // Native implementations
void ggml_quantize_mat_q8_0_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_quantize_mat_q8_0_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
@ -97,12 +106,14 @@ void ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs,
void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
#if defined(__cplusplus) #if defined(__cplusplus)
} // extern "C" } // extern "C"

View file

@ -87,6 +87,10 @@
#define GGML_CUDA_CC_IS_QY2(cc) (cc >= GGML_CUDA_CC_QY2 && cc < GGML_CUDA_CC_NG) #define GGML_CUDA_CC_IS_QY2(cc) (cc >= GGML_CUDA_CC_QY2 && cc < GGML_CUDA_CC_NG)
#define GGML_CUDA_CC_IS_NG(cc) (cc >= GGML_CUDA_CC_NG) #define GGML_CUDA_CC_IS_NG(cc) (cc >= GGML_CUDA_CC_NG)
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
# define GGML_CUDA_USE_CUB
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
#ifdef __CUDA_ARCH_LIST__ #ifdef __CUDA_ARCH_LIST__
constexpr bool ggml_cuda_has_arch_impl(int) { constexpr bool ggml_cuda_has_arch_impl(int) {
return false; return false;
@ -424,26 +428,6 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
#endif // FP16_AVAILABLE #endif // FP16_AVAILABLE
} }
// Row reduction kernel template - compute sum (norm=false) or mean (norm=true)
template<bool norm>
static __global__ void reduce_rows_f32(const float * x, float * dst, const int ncols) {
const int row = blockIdx.x;
const int col = threadIdx.x;
float sum = 0.0f;
for (int i = col; i < ncols; i += blockDim.x) {
sum += x[row * ncols + i];
}
sum = warp_reduce_sum(sum);
if (col != 0) {
return;
}
dst[row] = norm ? sum / ncols : sum;
}
template<int width = WARP_SIZE> template<int width = WARP_SIZE>
static __device__ __forceinline__ int warp_reduce_all(int x) { static __device__ __forceinline__ int warp_reduce_all(int x) {
#ifdef GGML_USE_HIP #ifdef GGML_USE_HIP
@ -484,25 +468,21 @@ static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b
} }
static __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const half2 b) { static __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const half2 b) {
#if defined(GGML_USE_HIP) && HIP_VERSION >= 50700000 #if defined(GGML_USE_HIP)
return half2(__hmax(a.x, b.x), __hmax(a.y, b.y)); return half2(__hmax(a.x, b.x), __hmax(a.y, b.y));
#elif !defined(GGML_USE_HIP) && CUDART_VERSION >= CUDART_HMAX #elif CUDART_VERSION >= CUDART_HMAX
return __hmax2(a, b); return __hmax2(a, b);
#elif !defined(GGML_USE_HIP) #else
half2 ret; half2 ret;
reinterpret_cast<half&>(ret.x) = __float2half(fmaxf( __low2float(a), __low2float(b))); reinterpret_cast<half&>(ret.x) = __float2half(fmaxf( __low2float(a), __low2float(b)));
reinterpret_cast<half&>(ret.y) = __float2half(fmaxf(__high2float(a), __high2float(b))); reinterpret_cast<half&>(ret.y) = __float2half(fmaxf(__high2float(a), __high2float(b)));
return ret; return ret;
#else
GGML_UNUSED(a);
GGML_UNUSED(b);
NO_DEVICE_CODE;
#endif #endif
} }
template<int width = WARP_SIZE> template<int width = WARP_SIZE>
static __device__ __forceinline__ half2 warp_reduce_max(half2 x) { static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL || (defined(GGML_USE_HIP) && HIP_VERSION >= 50700000) #if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL || defined(GGML_USE_HIP)
#pragma unroll #pragma unroll
for (int offset = width/2; offset > 0; offset >>= 1) { for (int offset = width/2; offset > 0; offset >>= 1) {
x = ggml_cuda_hmax2(x, __shfl_xor_sync(0xffffffff, x, offset, width)); x = ggml_cuda_hmax2(x, __shfl_xor_sync(0xffffffff, x, offset, width));
@ -511,7 +491,7 @@ static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
#else #else
GGML_UNUSED(x); GGML_UNUSED(x);
NO_DEVICE_CODE; NO_DEVICE_CODE;
#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL || (defined(GGML_USE_HIP) && HIP_VERSION >= 50700000) #endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL || defined(GGML_USE_HIP)
} }
#if CUDART_VERSION < CUDART_HMASK #if CUDART_VERSION < CUDART_HMASK

View file

@ -15,7 +15,6 @@ namespace wmma = mtmusa::wmma;
namespace wmma = nvcuda::wmma; namespace wmma = nvcuda::wmma;
#endif // GGML_USE_MUSA #endif // GGML_USE_MUSA
#elif defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE) #elif defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)
#undef HIP_ENABLE_WARP_SYNC_BUILTINS // conflicts with rocWMMA headers
#include <rocwmma/rocwmma.hpp> #include <rocwmma/rocwmma.hpp>
namespace wmma = rocwmma; namespace wmma = rocwmma;
#endif // !defined(GGML_USE_HIP) #endif // !defined(GGML_USE_HIP)

View file

@ -30,6 +30,7 @@ bool g_mul_mat_q = true;
#include "ggml-cuda/mmvq.cuh" #include "ggml-cuda/mmvq.cuh"
#include "ggml-cuda/norm.cuh" #include "ggml-cuda/norm.cuh"
#include "ggml-cuda/opt-step-adamw.cuh" #include "ggml-cuda/opt-step-adamw.cuh"
#include "ggml-cuda/opt-step-sgd.cuh"
#include "ggml-cuda/out-prod.cuh" #include "ggml-cuda/out-prod.cuh"
#include "ggml-cuda/pad.cuh" #include "ggml-cuda/pad.cuh"
#include "ggml-cuda/pool2d.cuh" #include "ggml-cuda/pool2d.cuh"
@ -182,30 +183,6 @@ static int ggml_cuda_parse_id(char devName[]) {
#endif // defined(GGML_USE_HIP) #endif // defined(GGML_USE_HIP)
static ggml_cuda_device_info ggml_cuda_init() { static ggml_cuda_device_info ggml_cuda_init() {
#if defined(GGML_USE_HIP)
// Workaround for a rocBLAS bug when using multiple graphics cards:
// https://github.com/ROCmSoftwarePlatform/rocBLAS/issues/1346
{
int major_version = 0;
size_t version_length = 0;
if (rocblas_get_version_string_size(&version_length) == rocblas_status_success) {
std::vector<char> version(version_length+1, '\0');
if (rocblas_get_version_string(version.data(), version.size()) == rocblas_status_success) {
version.resize(::strlen(version.data()));
int parsed_value = 0;
if (std::from_chars(version.data(), version.data() + version.size(), parsed_value).ec == std::errc()) {
major_version = parsed_value;
}
}
}
if (major_version < 4) {
GGML_LOG_DEBUG(GGML_CUDA_NAME " calling rocblas_initialize as a workaround for a rocBLAS bug\n");
rocblas_initialize();
CUDA_CHECK(cudaDeviceSynchronize());
}
}
#endif
ggml_cuda_device_info info = {}; ggml_cuda_device_info info = {};
cudaError_t err = cudaGetDeviceCount(&info.device_count); cudaError_t err = cudaGetDeviceCount(&info.device_count);
@ -2516,6 +2493,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_OPT_STEP_ADAMW: case GGML_OP_OPT_STEP_ADAMW:
ggml_cuda_opt_step_adamw(ctx, dst); ggml_cuda_opt_step_adamw(ctx, dst);
break; break;
case GGML_OP_OPT_STEP_SGD:
ggml_cuda_opt_step_sgd(ctx, dst);
break;
default: default:
return false; return false;
} }
@ -3573,6 +3553,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_CROSS_ENTROPY_LOSS: case GGML_OP_CROSS_ENTROPY_LOSS:
case GGML_OP_CROSS_ENTROPY_LOSS_BACK: case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
case GGML_OP_OPT_STEP_ADAMW: case GGML_OP_OPT_STEP_ADAMW:
case GGML_OP_OPT_STEP_SGD:
return true; return true;
default: default:
return false; return false;

View file

@ -1,4 +1,14 @@
#include "mean.cuh" #include "mean.cuh"
#include "reduce_rows.cuh"
#ifdef GGML_CUDA_USE_CUB
#include <cub/cub.cuh>
using namespace cub;
#endif // GGML_CUDA_USE_CUB
template <typename T> __global__ void divide_by_count(T * result, size_t count) {
*result /= static_cast<T>(count);
}
void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src0 = dst->src[0];
@ -13,7 +23,51 @@ void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const int64_t ncols = src0->ne[0]; const int64_t ncols = src0->ne[0];
const int64_t nrows = ggml_nrows(src0); const int64_t nrows = ggml_nrows(src0);
const dim3 block_dims(WARP_SIZE, 1, 1); // Special case for reducing vectors
const dim3 block_nums(nrows, 1, 1); #ifdef GGML_CUDA_USE_CUB
reduce_rows_f32</*norm*/ true><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols); #ifdef USE_CUDA_GRAPH
cudaStreamCaptureStatus iscapturing;
CUDA_CHECK(cudaStreamIsCapturing(stream, &iscapturing));
#endif // USE_CUDA_GRAPH
if ((nrows == 1) &&
#ifdef USE_CUDA_GRAPH
// CUDA_GRAPHS_DISABLED
((ncols > 65536) &&
((ctx.cuda_graph->instance == nullptr) && (iscapturing == cudaStreamCaptureStatusNone) ||
ctx.cuda_graph->disable_due_to_gpu_arch || ctx.cuda_graph->disable_due_to_too_many_updates ||
ctx.cuda_graph->disable_due_to_failed_graph_capture)) ||
// CUDA_GRAPHS ENABLED
((ncols > 32768) &&
!((ctx.cuda_graph->instance == nullptr) && (iscapturing == cudaStreamCaptureStatusNone) ||
ctx.cuda_graph->disable_due_to_gpu_arch || ctx.cuda_graph->disable_due_to_too_many_updates ||
ctx.cuda_graph->disable_due_to_failed_graph_capture))) {
#else
(ncols > 65536)) {
#endif // USE_CUDA_GRAPH
// Single row - use device-wide reduction
size_t tmp_size = 0;
ggml_cuda_pool & pool = ctx.pool();
DeviceReduce::Sum(nullptr, tmp_size, src0_d, dst_d, ncols, stream);
ggml_cuda_pool_alloc<uint8_t> tmp_alloc(pool, tmp_size);
DeviceReduce::Sum(tmp_alloc.ptr, tmp_size, src0_d, dst_d, ncols, stream);
// Divide by ncols
divide_by_count<float><<<1, 1, 0, stream>>>(dst_d, ncols);
return;
}
#endif // GGML_CUDA_USE_CUB
const dim3 block_nums(nrows, 1, 1);
const int id = ggml_cuda_get_device();
const int nsm = ggml_cuda_info().devices[id].nsm;
if ((nrows / nsm) < 2) {
const dim3 block_dims(512, 1, 1);
reduce_rows_f32</*norm=*/true><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
} else {
const dim3 block_dims(ncols < 1024 ? 32 : 128, 1, 1);
reduce_rows_f32</*norm=*/true><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
}
} }
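
The thresholds above only route single-row tensors with a large ncols to CUB (and only when CUDA graph capture does not get in the way); everything else stays on reduce_rows_f32 with the same block-size heuristic as sum_rows. For reference, a standalone host-side sketch of the bare two-phase CUB pattern the special case relies on; it needs nvcc and, per the new GGML_CUDA_USE_CUB gate, CUDA 11.7 or newer, and it is not ggml code:

// the first DeviceReduce::Sum call only queries the temp-storage size, the second one reduces
#include <cub/cub.cuh>
#include <cuda_runtime.h>
#include <vector>
#include <cstdio>

int main() {
    const int ncols = 1 << 20;
    std::vector<float> h(ncols, 0.5f);

    float * d_in  = nullptr;
    float * d_out = nullptr;
    cudaMalloc(&d_in, ncols * sizeof(float));
    cudaMalloc(&d_out, sizeof(float));
    cudaMemcpy(d_in, h.data(), ncols * sizeof(float), cudaMemcpyHostToDevice);

    size_t tmp_size = 0;
    cub::DeviceReduce::Sum(nullptr, tmp_size, d_in, d_out, ncols);   // size query
    void * d_tmp = nullptr;
    cudaMalloc(&d_tmp, tmp_size);
    cub::DeviceReduce::Sum(d_tmp, tmp_size, d_in, d_out, ncols);     // device-wide sum

    float sum = 0.0f;
    cudaMemcpy(&sum, d_out, sizeof(float), cudaMemcpyDeviceToHost);
    printf("mean = %f\n", sum / ncols);                              // divide on the host

    cudaFree(d_tmp);
    cudaFree(d_out);
    cudaFree(d_in);
    return 0;
}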

View file

@ -0,0 +1,49 @@
#include "ggml-impl.h"
#include "opt-step-sgd.cuh"
#include <cstdint>
static __global__ void opt_step_sgd_f32(
float * __restrict__ x, const float * __restrict__ g,
const float * __restrict__ pars, const int64_t k) {
const int64_t i = (int64_t) blockIdx.x*blockDim.x + threadIdx.x;
if (i >= k) {
return;
}
x[i] = x[i] * (1.0f - pars[0] * pars[1]) - pars[0] * g[i];
}
static void opt_step_sgd_f32_cuda(
float * x, const float * g, const float * __restrict__ pars, const int64_t k, cudaStream_t stream) {
const dim3 block_dims(CUDA_OPT_STEP_SGD_BLOCK_SIZE, 1, 1);
const dim3 block_nums((k + CUDA_OPT_STEP_SGD_BLOCK_SIZE - 1) / CUDA_OPT_STEP_SGD_BLOCK_SIZE, 1, 1);
opt_step_sgd_f32<<<block_nums, block_dims, 0, stream>>>(x, g, pars, k);
}
void ggml_cuda_opt_step_sgd(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src0_grad = dst->src[1];
const ggml_tensor * params = dst->src[2];
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src0_grad->type == GGML_TYPE_F32);
GGML_ASSERT(params->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_is_contiguous(src0_grad));
GGML_ASSERT(ggml_is_contiguous(params));
GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
GGML_ASSERT(ggml_nelements(params) == 2);
float * src0_d = (float *) src0->data;
const float * src0_grad_d = (const float *) src0_grad->data;
const float * params_d = (const float *) params->data;
cudaStream_t stream = ctx.stream();
const int64_t ne = ggml_nelements(src0);
opt_step_sgd_f32_cuda(src0_d, src0_grad_d, params_d, ne, stream);
}
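
pars is the 2-element tensor filled by ggml-opt.cpp: pars[0] is the learning rate alpha and pars[1] the weight-decay factor wd, so each element is updated as x = x * (1 - alpha*wd) - alpha*g (SGD with decoupled weight decay, the same formula as the Vulkan shader later in this change). A plain CPU reference of that update, useful as a sketch for validating the kernel and not part of ggml:

#include <cstdint>
#include <cstdio>

// mirrors opt_step_sgd_f32 above, one element at a time on the host
static void opt_step_sgd_f32_ref(float * x, const float * g, const float * pars, int64_t k) {
    const float alpha = pars[0];
    const float keep  = 1.0f - alpha * pars[1];   // decoupled weight decay
    for (int64_t i = 0; i < k; ++i) {
        x[i] = x[i] * keep - alpha * g[i];
    }
}

int main() {
    float x[3]    = { 1.0f, -2.0f, 0.5f };
    float g[3]    = { 0.1f,  0.2f, -0.3f };
    float pars[2] = { 1e-2f, 1e-1f };             // alpha, wd
    opt_step_sgd_f32_ref(x, g, pars, 3);
    printf("%f %f %f\n", x[0], x[1], x[2]);
    return 0;
}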

View file

@ -0,0 +1,5 @@
#include "common.cuh"
#define CUDA_OPT_STEP_SGD_BLOCK_SIZE 256
void ggml_cuda_opt_step_sgd(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View file

@ -0,0 +1,53 @@
#include "common.cuh"
// Row reduction kernel template - compute sum (norm=false) or mean (norm=true)
template <bool norm>
static __global__ void reduce_rows_f32(const float * __restrict__ x, float * __restrict__ dst, const int ncols) {
const int row = blockIdx.x;
const int col = threadIdx.x;
float sum = 0.0f;
const int num_unroll = 8;
float temp[num_unroll];
float sum_temp[num_unroll] = { 0.0f };
for (int i = col; i < ncols;) {
for (int j = 0; j < num_unroll; ++j) {
if (i < ncols) {
temp[j] = x[row * ncols + i];
} else {
temp[j] = 0;
}
i += blockDim.x;
}
for (int j = 0; j < num_unroll; ++j) {
sum_temp[j] += temp[j];
}
}
for (int j = 0; j < num_unroll; ++j) {
sum += sum_temp[j];
}
// sum up partial sums
sum = warp_reduce_sum(sum);
if (blockDim.x > WARP_SIZE) {
assert((blockDim.x <= 1024) && (blockDim.x % WARP_SIZE) == 0);
__shared__ float s_sum[32];
const int warp_id = threadIdx.x / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE;
if (lane_id == 0) {
s_sum[warp_id] = sum;
}
__syncthreads();
sum = 0.0f;
if (lane_id < (blockDim.x / WARP_SIZE)) {
sum = s_sum[lane_id];
}
sum = warp_reduce_sum(sum);
}
if (col != 0) {
return;
}
dst[row] = norm ? sum / ncols : sum;
}

View file

@ -1,19 +1,15 @@
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070 #include "sum.cuh"
#define USE_CUB #include "sumrows.cuh"
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
#ifdef USE_CUB #ifdef GGML_CUDA_USE_CUB
#include <cub/cub.cuh> #include <cub/cub.cuh>
using namespace cub; using namespace cub;
#endif // USE_CUB #endif // GGML_CUDA_USE_CUB
#include "sumrows.cuh"
#include "sum.cuh"
#include <cstdint> #include <cstdint>
void sum_f32_cuda(ggml_cuda_pool & pool, const float * x, float * dst, const int64_t ne, cudaStream_t stream) { void sum_f32_cuda(ggml_cuda_pool & pool, const float * x, float * dst, const int64_t ne, cudaStream_t stream) {
#ifdef USE_CUB #ifdef GGML_CUDA_USE_CUB
size_t tmp_size = 0; size_t tmp_size = 0;
DeviceReduce::Sum(nullptr, tmp_size, x, dst, ne, stream); DeviceReduce::Sum(nullptr, tmp_size, x, dst, ne, stream);
ggml_cuda_pool_alloc<uint8_t> tmp_alloc(pool, tmp_size); ggml_cuda_pool_alloc<uint8_t> tmp_alloc(pool, tmp_size);
@ -23,7 +19,7 @@ void sum_f32_cuda(ggml_cuda_pool & pool, const float * x, float * dst, const int
// For AMD there is rocPRIM which could be used as a drop-in replacement via hipcub but this would require C++11 -> C++14. // For AMD there is rocPRIM which could be used as a drop-in replacement via hipcub but this would require C++11 -> C++14.
sum_rows_f32_cuda(x, dst, ne, 1, stream); sum_rows_f32_cuda(x, dst, ne, 1, stream);
GGML_UNUSED(pool); GGML_UNUSED(pool);
#endif // USE_CUB #endif // GGML_CUDA_USE_CUB
} }
void ggml_cuda_op_sum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { void ggml_cuda_op_sum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {

View file

@ -1,9 +1,17 @@
#include "reduce_rows.cuh"
#include "sumrows.cuh" #include "sumrows.cuh"
void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) { void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
const dim3 block_dims(WARP_SIZE, 1, 1); const int id = ggml_cuda_get_device();
const int nsm = ggml_cuda_info().devices[id].nsm;
const dim3 block_nums(nrows, 1, 1); const dim3 block_nums(nrows, 1, 1);
reduce_rows_f32</*norm*/false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols); if ((nrows / nsm) < 2) {
const dim3 block_dims(512, 1, 1);
reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
} else {
const dim3 block_dims(ncols < 1024 ? 32 : 128, 1, 1);
reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
}
} }
void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@ -19,8 +27,17 @@ void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const int64_t ncols = src0->ne[0]; const int64_t ncols = src0->ne[0];
const int64_t nrows = ggml_nrows(src0); const int64_t nrows = ggml_nrows(src0);
const dim3 block_dims(WARP_SIZE, 1, 1);
const dim3 block_nums(nrows, 1, 1); const dim3 block_nums(nrows, 1, 1);
const int id = ggml_cuda_get_device();
const int nsm = ggml_cuda_info().devices[id].nsm;
if ((nrows / nsm) < 2) {
// Increase num threads to 512 for small nrows to better hide the latency
const dim3 block_dims(512, 1, 1);
reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
} else {
// Enough active SMs to hide latency; use smaller blocks to allow better scheduling
const dim3 block_dims(ncols < 1024 ? 32 : 128, 1, 1);
reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols); reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
} }
}
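
sum_rows and mean now share this launch heuristic: with fewer than two rows per SM the kernel is latency-bound, so a 512-thread block keeps each SM busy, while with many rows the smaller 32- or 128-thread blocks (depending on ncols) leave room for more resident blocks per SM. A small illustrative helper that captures the choice; the function name is ours, not ggml's:

#include <cstdint>
#include <cstdio>

// returns the block size reduce_rows_f32 is launched with, mirroring the logic above
static unsigned pick_reduce_rows_block_size(int64_t nrows, int nsm, int64_t ncols) {
    if (nrows / nsm < 2) {
        return 512;                       // few rows: hide latency with a large block
    }
    return ncols < 1024 ? 32 : 128;       // many rows: favor more resident blocks per SM
}

int main() {
    printf("%u\n", pick_reduce_rows_block_size(   8, 80, 4096));   // 512
    printf("%u\n", pick_reduce_rows_block_size(4096, 80,  512));   // 32
    return 0;
}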

View file

@ -1,12 +1,10 @@
#pragma once #pragma once
#define HIP_ENABLE_WARP_SYNC_BUILTINS 1 #define HIP_DISABLE_WARP_SYNC_BUILTINS 1
#include <hip/hip_runtime.h> #include <hip/hip_runtime.h>
#include <hipblas/hipblas.h> #include <hipblas/hipblas.h>
#include <hip/hip_fp16.h> #include <hip/hip_fp16.h>
#include <hip/hip_bfloat16.h> #include <hip/hip_bfloat16.h>
// for rocblas_initialize()
#include "rocblas/rocblas.h"
#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT #define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT #define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
@ -251,17 +249,3 @@ static __device__ __forceinline__ unsigned int __vcmpne4(unsigned int a, unsigne
} }
return c; return c;
} }
#if HIP_VERSION < 50600000
// __shfl_xor() for half2 was added in ROCm 5.6
static __device__ __forceinline__ half2 __shfl_xor(half2 var, int laneMask, int width) {
typedef union half2_b32 {
half2 val;
int b32;
} half2_b32_t;
half2_b32_t tmp;
tmp.val = var;
tmp.b32 = __shfl_xor(tmp.b32, laneMask, width);
return tmp.val;
}
#endif // HIP_VERSION < 50600000

View file

@ -66,7 +66,9 @@ struct ggml_opt_context {
ggml_opt_get_optimizer_params get_opt_pars = nullptr; ggml_opt_get_optimizer_params get_opt_pars = nullptr;
void * get_opt_pars_ud = nullptr; void * get_opt_pars_ud = nullptr;
struct ggml_tensor * adamw_params = nullptr; struct ggml_tensor * opt_step_params = nullptr; // Stores output of get_opt_pars.
enum ggml_opt_optimizer_type optimizer = GGML_OPT_OPTIMIZER_TYPE_ADAMW;
}; };
struct ggml_opt_result { struct ggml_opt_result {
@ -229,9 +231,13 @@ struct ggml_opt_optimizer_params ggml_opt_get_default_optimizer_params(void * us
result.adamw.eps = 1e-8f; result.adamw.eps = 1e-8f;
result.adamw.wd = 0.0f; result.adamw.wd = 0.0f;
result.sgd.alpha = 1e-3f;
result.sgd.wd = 0.0f;
return result; return result;
} }
struct ggml_opt_optimizer_params ggml_opt_get_constant_optimizer_params(void * userdata) { struct ggml_opt_optimizer_params ggml_opt_get_constant_optimizer_params(void * userdata) {
return *((struct ggml_opt_optimizer_params *) userdata); return *((struct ggml_opt_optimizer_params *) userdata);
} }
@ -249,6 +255,7 @@ struct ggml_opt_params ggml_opt_default_params(
/*opt_period =*/ 1, /*opt_period =*/ 1,
/*get_opt_pars =*/ ggml_opt_get_default_optimizer_params, /*get_opt_pars =*/ ggml_opt_get_default_optimizer_params,
/*get_opt_pars_ud =*/ nullptr, /*get_opt_pars_ud =*/ nullptr,
/*optimizer =*/ GGML_OPT_OPTIMIZER_TYPE_ADAMW,
}; };
} }
@ -316,9 +323,14 @@ static void ggml_opt_build(ggml_opt_context_t opt_ctx) {
GGML_ASSERT(opt_ctx->ctx_compute && "no compute context set, either use static graphs or set one with ggml_opt_prepare_alloc"); GGML_ASSERT(opt_ctx->ctx_compute && "no compute context set, either use static graphs or set one with ggml_opt_prepare_alloc");
GGML_ASSERT((!opt_ctx->static_graphs || opt_ctx->inputs->data) && "when using static graphs the inputs must be allocated statically"); GGML_ASSERT((!opt_ctx->static_graphs || opt_ctx->inputs->data) && "when using static graphs the inputs must be allocated statically");
const enum ggml_opt_optimizer_type optimizer = opt_ctx->optimizer;
const bool accumulate = opt_ctx->build_type_alloc >= GGML_OPT_BUILD_TYPE_GRAD && const bool accumulate = opt_ctx->build_type_alloc >= GGML_OPT_BUILD_TYPE_GRAD &&
!(opt_ctx->static_graphs && opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_OPT && opt_ctx->opt_period == 1); !(opt_ctx->static_graphs && opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_OPT && opt_ctx->opt_period == 1);
const bool need_momenta = opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_OPT &&
opt_ctx->optimizer == GGML_OPT_OPTIMIZER_TYPE_ADAMW;
ggml_set_input(opt_ctx->inputs); ggml_set_input(opt_ctx->inputs);
ggml_set_output(opt_ctx->outputs); ggml_set_output(opt_ctx->outputs);
@ -340,8 +352,7 @@ static void ggml_opt_build(ggml_opt_context_t opt_ctx) {
// - pred (if using static graphs) // - pred (if using static graphs)
// - ncorrect (if using static graphs, 2 tensors). // - ncorrect (if using static graphs, 2 tensors).
constexpr size_t n_loss = 1; constexpr size_t n_loss = 1;
const size_t tensors_per_param = (accumulate ? 1 : 0) + const size_t tensors_per_param = (accumulate ? 1 : 0) + (need_momenta ? 2 : 0);
(opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_OPT ? 2 : 0);
const size_t tensors_const = opt_ctx->static_graphs ? 9 : 0; const size_t tensors_const = opt_ctx->static_graphs ? 9 : 0;
const size_t size_meta = (n_loss + tensors_per_param*n_param + tensors_const) * ggml_tensor_overhead(); const size_t size_meta = (n_loss + tensors_per_param*n_param + tensors_const) * ggml_tensor_overhead();
struct ggml_init_params params = { struct ggml_init_params params = {
@ -458,7 +469,7 @@ static void ggml_opt_build(ggml_opt_context_t opt_ctx) {
} }
} }
if (opt_ctx->build_type_alloc >= GGML_OPT_BUILD_TYPE_OPT) { if (need_momenta && opt_ctx->build_type_alloc >= GGML_OPT_BUILD_TYPE_OPT) {
opt_ctx->grad_m.resize(n_nodes); opt_ctx->grad_m.resize(n_nodes);
opt_ctx->grad_v.resize(n_nodes); opt_ctx->grad_v.resize(n_nodes);
for (int i = 0; i < n_nodes; ++i) { for (int i = 0; i < n_nodes; ++i) {
@ -492,23 +503,36 @@ static void ggml_opt_build(ggml_opt_context_t opt_ctx) {
// gb_opt == graph backward optimize, forward pass, then backward pass to calculate gradients, then optimizer step. // gb_opt == graph backward optimize, forward pass, then backward pass to calculate gradients, then optimizer step.
opt_ctx->gb_opt = ggml_graph_dup(opt_ctx->ctx_compute, opt_ctx->gb_grad, /*force_grads =*/ true); opt_ctx->gb_opt = ggml_graph_dup(opt_ctx->ctx_compute, opt_ctx->gb_grad, /*force_grads =*/ true);
opt_ctx->adamw_params = ggml_new_tensor_1d(opt_ctx->ctx_cpu, GGML_TYPE_F32, 7); opt_ctx->opt_step_params = ggml_new_tensor_1d(opt_ctx->ctx_cpu, GGML_TYPE_F32, need_momenta ? 7 : 2);
ggml_set_input(opt_ctx->adamw_params); ggml_tensor * adamw_params = opt_ctx->opt_step_params;
ggml_set_name(opt_ctx->adamw_params, "adamw_params"); ggml_set_input(adamw_params);
const char * optimizer_name = ggml_opt_optimizer_name(opt_ctx->optimizer);
ggml_format_name(adamw_params, "%s_params", optimizer_name);
for (int i = opt_ctx->gf->n_nodes-1; i >= 0; --i) { for (int i = opt_ctx->gf->n_nodes-1; i >= 0; --i) {
struct ggml_tensor * node = opt_ctx->gb_opt->nodes[i]; struct ggml_tensor * node = opt_ctx->gb_opt->nodes[i];
struct ggml_tensor * grad = ggml_graph_get_grad(opt_ctx->gb_opt, node); struct ggml_tensor * grad = ggml_graph_get_grad(opt_ctx->gb_opt, node);
if (grad && (node->flags & GGML_TENSOR_FLAG_PARAM)) { if (grad && (node->flags & GGML_TENSOR_FLAG_PARAM)) {
struct ggml_tensor * m = opt_ctx->grad_m[i]; struct ggml_tensor * m = nullptr;
struct ggml_tensor * v = opt_ctx->grad_v[i]; struct ggml_tensor * v = nullptr;
struct ggml_tensor * opt_step = ggml_opt_step_adamw(opt_ctx->ctx_compute, node, grad, m, v, opt_ctx->adamw_params); if (need_momenta) {
m = opt_ctx->grad_m[i];
ggml_set_name(m, (std::string("AdamW m for ") + std::string(node->name)).c_str()); v = opt_ctx->grad_v[i];
ggml_set_name(v, (std::string("AdamW v for ") + std::string(node->name)).c_str()); ggml_format_name(m, "AdamW m for %s", node->name);
ggml_set_name(opt_step, (std::string("AdamW step for ") + std::string(node->name)).c_str()); ggml_format_name(v, "AdamW v for %s", node->name);
}
struct ggml_tensor * opt_step;
switch (optimizer) {
case GGML_OPT_OPTIMIZER_TYPE_ADAMW:
opt_step = ggml_opt_step_adamw(opt_ctx->ctx_compute, node, grad, m, v, adamw_params);
break;
case GGML_OPT_OPTIMIZER_TYPE_SGD:
opt_step = ggml_opt_step_sgd(opt_ctx->ctx_compute, node, grad, adamw_params);
break;
default:
GGML_ABORT("fatal error");
}
ggml_format_name(opt_step, "%s step for %s", optimizer_name, node->name);
ggml_build_forward_expand(opt_ctx->gb_opt, opt_step); ggml_build_forward_expand(opt_ctx->gb_opt, opt_step);
} }
} }
@ -534,6 +558,7 @@ ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params) {
result->opt_period = params.opt_period; result->opt_period = params.opt_period;
result->get_opt_pars = params.get_opt_pars; result->get_opt_pars = params.get_opt_pars;
result->get_opt_pars_ud = params.get_opt_pars_ud; result->get_opt_pars_ud = params.get_opt_pars_ud;
result->optimizer = params.optimizer;
GGML_ASSERT(result->opt_period >= 1); GGML_ASSERT(result->opt_period >= 1);
@ -756,8 +781,10 @@ void ggml_opt_alloc(ggml_opt_context_t opt_ctx, bool backward) {
void ggml_opt_eval(ggml_opt_context_t opt_ctx, ggml_opt_result_t result) { void ggml_opt_eval(ggml_opt_context_t opt_ctx, ggml_opt_result_t result) {
GGML_ASSERT(opt_ctx->eval_ready); GGML_ASSERT(opt_ctx->eval_ready);
if (opt_ctx->allocated_graph == opt_ctx->gb_opt) { if (opt_ctx->allocated_graph == opt_ctx->gb_opt) {
struct ggml_opt_optimizer_params opt_pars = opt_ctx->get_opt_pars(opt_ctx->get_opt_pars_ud); const ggml_opt_optimizer_params & opt_pars = opt_ctx->get_opt_pars(opt_ctx->get_opt_pars_ud);
switch (opt_ctx->optimizer) {
case GGML_OPT_OPTIMIZER_TYPE_ADAMW: {
GGML_ASSERT(opt_pars.adamw.alpha > 0.0f); GGML_ASSERT(opt_pars.adamw.alpha > 0.0f);
GGML_ASSERT(opt_pars.adamw.beta1 >= 0.0f); GGML_ASSERT(opt_pars.adamw.beta1 >= 0.0f);
GGML_ASSERT(opt_pars.adamw.beta1 <= 1.0f); GGML_ASSERT(opt_pars.adamw.beta1 <= 1.0f);
@ -771,7 +798,7 @@ void ggml_opt_eval(ggml_opt_context_t opt_ctx, ggml_opt_result_t result) {
const float beta1h = 1.0f / (1.0f - powf(opt_pars.adamw.beta1, opt_ctx->iter)); const float beta1h = 1.0f / (1.0f - powf(opt_pars.adamw.beta1, opt_ctx->iter));
const float beta2h = 1.0f / (1.0f - powf(opt_pars.adamw.beta2, opt_ctx->iter)); const float beta2h = 1.0f / (1.0f - powf(opt_pars.adamw.beta2, opt_ctx->iter));
float * adamw_par_data = ggml_get_data_f32(opt_ctx->adamw_params); float * adamw_par_data = ggml_get_data_f32(opt_ctx->opt_step_params);
adamw_par_data[0] = opt_pars.adamw.alpha; adamw_par_data[0] = opt_pars.adamw.alpha;
adamw_par_data[1] = opt_pars.adamw.beta1; adamw_par_data[1] = opt_pars.adamw.beta1;
adamw_par_data[2] = opt_pars.adamw.beta2; adamw_par_data[2] = opt_pars.adamw.beta2;
@ -779,6 +806,18 @@ void ggml_opt_eval(ggml_opt_context_t opt_ctx, ggml_opt_result_t result) {
adamw_par_data[4] = opt_pars.adamw.wd; adamw_par_data[4] = opt_pars.adamw.wd;
adamw_par_data[5] = beta1h; adamw_par_data[5] = beta1h;
adamw_par_data[6] = beta2h; adamw_par_data[6] = beta2h;
} break;
case GGML_OPT_OPTIMIZER_TYPE_SGD: {
GGML_ASSERT(opt_pars.sgd.alpha > 0.0f);
GGML_ASSERT(opt_pars.sgd.wd >= 0.0f);
GGML_ASSERT(opt_pars.sgd.wd <= 1.0f);
float * sgd = ggml_get_data_f32(opt_ctx->opt_step_params);
sgd[0] = opt_pars.sgd.alpha;
sgd[1] = opt_pars.sgd.wd;
} break;
default:
GGML_ABORT("fatal error");
}
} }
ggml_backend_sched_graph_compute(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy); ggml_backend_sched_graph_compute(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy);
@ -963,6 +1002,7 @@ void ggml_opt_fit(
ggml_tensor * outputs, ggml_tensor * outputs,
ggml_opt_dataset_t dataset, ggml_opt_dataset_t dataset,
enum ggml_opt_loss_type loss_type, enum ggml_opt_loss_type loss_type,
enum ggml_opt_optimizer_type optimizer,
ggml_opt_get_optimizer_params get_opt_pars, ggml_opt_get_optimizer_params get_opt_pars,
int64_t nepoch, int64_t nepoch,
int64_t nbatch_logical, int64_t nbatch_logical,
@ -993,6 +1033,7 @@ void ggml_opt_fit(
params.opt_period = opt_period; params.opt_period = opt_period;
params.get_opt_pars = get_opt_pars; params.get_opt_pars = get_opt_pars;
params.get_opt_pars_ud = &epoch; params.get_opt_pars_ud = &epoch;
params.optimizer = optimizer;
ggml_opt_context_t opt_ctx = ggml_opt_init(params); ggml_opt_context_t opt_ctx = ggml_opt_init(params);
// Shuffling the data is generally useful but there is only a point if not all data is used in a single batch. // Shuffling the data is generally useful but there is only a point if not all data is used in a single batch.
@ -1035,3 +1076,18 @@ void ggml_opt_fit(
ggml_opt_result_free(result_train); ggml_opt_result_free(result_train);
ggml_opt_result_free(result_val); ggml_opt_result_free(result_val);
} }
enum ggml_opt_optimizer_type ggml_opt_context_optimizer_type(ggml_opt_context_t c) {
return c->optimizer;
}
GGML_API const char * ggml_opt_optimizer_name(enum ggml_opt_optimizer_type o) {
switch (o) {
case GGML_OPT_OPTIMIZER_TYPE_ADAMW:
return "adamw";
case GGML_OPT_OPTIMIZER_TYPE_SGD:
return "sgd";
default:
return "undefined";
};
}
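
With the optimizer now a property of the context, callers pick it once in ggml_opt_params and provide the matching parameters through get_opt_pars; for SGD only the two sgd fields are consumed and the params tensor shrinks from seven floats to two. A usage sketch against the fields shown in this diff; the backend-scheduler and loss setup is omitted and assumed to follow the usual ggml_opt_init flow:

#include "ggml-opt.h"

// callback in the same shape as ggml_opt_get_default_optimizer_params above:
// start from the defaults, then override the SGD branch
static struct ggml_opt_optimizer_params my_sgd_pars(void * userdata) {
    struct ggml_opt_optimizer_params p = ggml_opt_get_default_optimizer_params(userdata);
    p.sgd.alpha = 1e-3f;   // learning rate
    p.sgd.wd    = 1e-2f;   // weight decay, applied as x *= (1 - alpha*wd)
    return p;
}

// then, when building the optimization context (sketch, surrounding setup omitted):
//   struct ggml_opt_params params = ggml_opt_default_params(/*backend sched*/ ..., /*loss type*/ ...);
//   params.optimizer       = GGML_OPT_OPTIMIZER_TYPE_SGD;   // new field in this change
//   params.get_opt_pars    = my_sgd_pars;
//   params.get_opt_pars_ud = nullptr;
//   ggml_opt_context_t opt_ctx = ggml_opt_init(params);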

View file

@ -526,6 +526,7 @@ struct vk_device_struct {
vk_pipeline pipeline_rwkv_wkv6_f32; vk_pipeline pipeline_rwkv_wkv6_f32;
vk_pipeline pipeline_rwkv_wkv7_f32; vk_pipeline pipeline_rwkv_wkv7_f32;
vk_pipeline pipeline_opt_step_adamw_f32; vk_pipeline pipeline_opt_step_adamw_f32;
vk_pipeline pipeline_opt_step_sgd_f32;
vk_pipeline pipeline_conv2d_f32[CONV_SHAPE_COUNT]; vk_pipeline pipeline_conv2d_f32[CONV_SHAPE_COUNT];
vk_pipeline pipeline_conv2d_f16_f32[CONV_SHAPE_COUNT]; vk_pipeline pipeline_conv2d_f16_f32[CONV_SHAPE_COUNT];
vk_pipeline pipeline_conv2d_dw_whcn_f32; vk_pipeline pipeline_conv2d_dw_whcn_f32;
@ -3139,6 +3140,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_opt_step_sgd_f32, "opt_step_sgd_f32", opt_step_sgd_f32_len, opt_step_sgd_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
// conv2d // conv2d
for (uint32_t s = 0; s < CONV_SHAPE_COUNT; ++s) { for (uint32_t s = 0; s < CONV_SHAPE_COUNT; ++s) {
uint32_t conv2d_WG_SIZE = 256; uint32_t conv2d_WG_SIZE = 256;
@ -7223,6 +7226,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return ctx->device->pipeline_opt_step_adamw_f32; return ctx->device->pipeline_opt_step_adamw_f32;
} }
return nullptr; return nullptr;
case GGML_OP_OPT_STEP_SGD:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_opt_step_sgd_f32;
}
return nullptr;
case GGML_OP_LEAKY_RELU: case GGML_OP_LEAKY_RELU:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_leaky_relu_f32; return ctx->device->pipeline_leaky_relu_f32;
@ -7722,6 +7730,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
ggml_vk_buffer_memset_async(subctx, d_D, d_buf_offset, 0, d_sz); ggml_vk_buffer_memset_async(subctx, d_D, d_buf_offset, 0, d_sz);
ggml_vk_sync_buffers(subctx); ggml_vk_sync_buffers(subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
} else if (op == GGML_OP_OPT_STEP_SGD) {
// OPT_STEP_SGD works on src0 in place; it does not need dst
ggml_vk_sync_buffers(subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz } }, pc, elements);
} else if (use_src2) { } else if (use_src2) {
ggml_vk_sync_buffers(subctx); ggml_vk_sync_buffers(subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
@ -8075,6 +8087,12 @@ static void ggml_vk_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& su
); );
} }
static void ggml_vk_opt_step_sgd(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) {
const size_t n = ggml_nelements(dst->src[0]);
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_OPT_STEP_SGD, { (uint32_t)n, 0, 0.0f, 0.0f }, dryrun);
}
static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
int * op_params = (int *)dst->op_params; int * op_params = (int *)dst->op_params;
@ -9628,6 +9646,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_OP_LEAKY_RELU: case GGML_OP_LEAKY_RELU:
case GGML_OP_FLASH_ATTN_EXT: case GGML_OP_FLASH_ATTN_EXT:
case GGML_OP_OPT_STEP_ADAMW: case GGML_OP_OPT_STEP_ADAMW:
case GGML_OP_OPT_STEP_SGD:
break; break;
default: default:
std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(node->op) << std::endl; std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(node->op) << std::endl;
@ -9692,6 +9711,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_OP_CONV_2D: case GGML_OP_CONV_2D:
case GGML_OP_CONV_2D_DW: case GGML_OP_CONV_2D_DW:
case GGML_OP_LEAKY_RELU: case GGML_OP_LEAKY_RELU:
case GGML_OP_OPT_STEP_SGD:
{ {
// These operations all go through ggml_vk_op_f32, so short-circuit and // These operations all go through ggml_vk_op_f32, so short-circuit and
// do the only thing needed for the dryrun. // do the only thing needed for the dryrun.
@ -9941,6 +9961,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_OP_OPT_STEP_ADAMW: case GGML_OP_OPT_STEP_ADAMW:
ggml_vk_opt_step_adamw(ctx, compute_ctx, node, dryrun); ggml_vk_opt_step_adamw(ctx, compute_ctx, node, dryrun);
break;
case GGML_OP_OPT_STEP_SGD:
ggml_vk_opt_step_sgd(ctx, compute_ctx, src0, src1, src2, node, dryrun);
break; break;
default: default:
return false; return false;
@ -10044,8 +10069,8 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
case GGML_OP_REPEAT: case GGML_OP_REPEAT:
case GGML_OP_REPEAT_BACK: case GGML_OP_REPEAT_BACK:
case GGML_OP_OPT_STEP_ADAMW: case GGML_OP_OPT_STEP_ADAMW:
case GGML_OP_OPT_STEP_SGD:
buf = tensor->buffer; buf = tensor->buffer;
break; break;
case GGML_OP_UNARY: case GGML_OP_UNARY:
switch (ggml_get_unary_op(tensor)) { switch (ggml_get_unary_op(tensor)) {
@ -11184,6 +11209,9 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_OP_SIN: case GGML_OP_SIN:
case GGML_OP_COS: case GGML_OP_COS:
case GGML_OP_CLAMP: case GGML_OP_CLAMP:
case GGML_OP_LEAKY_RELU:
case GGML_OP_OPT_STEP_ADAMW:
case GGML_OP_OPT_STEP_SGD:
return op->src[0]->type == GGML_TYPE_F32; return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_UPSCALE: case GGML_OP_UPSCALE:
case GGML_OP_ACC: case GGML_OP_ACC:
@ -11205,8 +11233,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_OP_POOL_2D: case GGML_OP_POOL_2D:
case GGML_OP_RWKV_WKV6: case GGML_OP_RWKV_WKV6:
case GGML_OP_RWKV_WKV7: case GGML_OP_RWKV_WKV7:
case GGML_OP_LEAKY_RELU:
case GGML_OP_OPT_STEP_ADAMW:
return true; return true;
case GGML_OP_CONV_TRANSPOSE_1D: case GGML_OP_CONV_TRANSPOSE_1D:
return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32; return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
@ -11804,6 +11830,10 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
src_clone[0]->flags = src0->flags; src_clone[0]->flags = src0->flags;
tensor_clone = ggml_opt_step_adamw(ggml_ctx, src_clone[0], src_clone[1], tensor_clone = ggml_opt_step_adamw(ggml_ctx, src_clone[0], src_clone[1],
src_clone[2], src_clone[3], src_clone[4]); src_clone[2], src_clone[3], src_clone[4]);
} else if (tensor->op == GGML_OP_OPT_STEP_SGD) {
src_clone[0]->flags = src0->flags;
tensor_clone = ggml_opt_step_sgd(ggml_ctx, src_clone[0], src_clone[1],
src_clone[2]);
} }
else { else {
std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl; std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;

View file

@ -0,0 +1,22 @@
#version 450
#include "generic_head.comp"
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) buffer X {A_TYPE data_x[];};
layout (binding = 1) readonly buffer G {A_TYPE data_grad[];};
layout (binding = 2) readonly buffer P {float data_params[2];};
void main() {
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
if (i >= p.KX) {
return;
}
const float alpha = data_params[0];
const float keep = 1.f - alpha * data_params[1];
data_x[i] = data_x[i] * keep - alpha * data_grad[i];
}

View file

@ -671,6 +671,7 @@ void process_shaders() {
string_to_spv("rwkv_wkv7_f32", "wkv7.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); string_to_spv("rwkv_wkv7_f32", "wkv7.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
string_to_spv("opt_step_sgd_f32", "opt_step_sgd.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
string_to_spv("conv2d_f32_unroll", "conv2d_mm.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", "[[unroll]]"}}); string_to_spv("conv2d_f32_unroll", "conv2d_mm.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", "[[unroll]]"}});
string_to_spv("conv2d_f16_f32_unroll", "conv2d_mm.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", "[[unroll]]"}}); string_to_spv("conv2d_f16_f32_unroll", "conv2d_mm.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", "[[unroll]]"}});

View file

@ -1028,11 +1028,12 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"CROSS_ENTROPY_LOSS", "CROSS_ENTROPY_LOSS",
"CROSS_ENTROPY_LOSS_BACK", "CROSS_ENTROPY_LOSS_BACK",
"OPT_STEP_ADAMW", "OPT_STEP_ADAMW",
"OPT_STEP_SGD",
"GLU", "GLU",
}; };
static_assert(GGML_OP_COUNT == 87, "GGML_OP_COUNT != 87"); static_assert(GGML_OP_COUNT == 88, "GGML_OP_COUNT != 88");
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none", "none",
@ -1129,15 +1130,15 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"cross_entropy_loss(x,y)", "cross_entropy_loss(x,y)",
"cross_entropy_loss_back(x,y)", "cross_entropy_loss_back(x,y)",
"adamw(x)", "adamw(x)",
"sgd(x)",
"glu(x)", "glu(x)",
}; };
static_assert(GGML_OP_COUNT == 87, "GGML_OP_COUNT != 87"); static_assert(GGML_OP_COUNT == 88, "GGML_OP_COUNT != 88");
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = { static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
"ABS", "ABS",
"SGN", "SGN",
@ -3901,6 +3902,7 @@ static struct ggml_tensor * ggml_rope_impl(
struct ggml_tensor * b, struct ggml_tensor * b,
struct ggml_tensor * c, struct ggml_tensor * c,
int n_dims, int n_dims,
int sections[GGML_MROPE_SECTIONS],
int mode, int mode,
int n_ctx_orig, int n_ctx_orig,
float freq_base, float freq_base,
@ -3914,15 +3916,19 @@ static struct ggml_tensor * ggml_rope_impl(
GGML_ASSERT(ggml_is_vector(b)); GGML_ASSERT(ggml_is_vector(b));
GGML_ASSERT(b->type == GGML_TYPE_I32); GGML_ASSERT(b->type == GGML_TYPE_I32);
bool mrope_used = mode & GGML_ROPE_TYPE_MROPE;
if (mrope_used) {
GGML_ASSERT(a->ne[2] * 4 == b->ne[0]); // mrope expecting 4 position ids per token
} else {
GGML_ASSERT(a->ne[2] == b->ne[0]); GGML_ASSERT(a->ne[2] == b->ne[0]);
}
if (c) { if (c) {
GGML_ASSERT(c->type == GGML_TYPE_F32); GGML_ASSERT(c->type == GGML_TYPE_F32);
GGML_ASSERT(c->ne[0] >= n_dims / 2); GGML_ASSERT(c->ne[0] >= n_dims / 2);
} }
int sections[4] = {0, 0, 0, 0};
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
int32_t params[15] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig }; int32_t params[15] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
@ -3932,7 +3938,11 @@ static struct ggml_tensor * ggml_rope_impl(
memcpy(params + 8, &attn_factor, sizeof(float)); memcpy(params + 8, &attn_factor, sizeof(float));
memcpy(params + 9, &beta_fast, sizeof(float)); memcpy(params + 9, &beta_fast, sizeof(float));
memcpy(params + 10, &beta_slow, sizeof(float)); memcpy(params + 10, &beta_slow, sizeof(float));
memcpy(params + 11, &sections, sizeof(int)*4); if (mrope_used) {
memcpy(params + 11, sections, sizeof(int32_t) * GGML_MROPE_SECTIONS);
} else {
memset(params + 11, 0, sizeof(int32_t) * GGML_MROPE_SECTIONS);
}
ggml_set_op_params(result, params, sizeof(params)); ggml_set_op_params(result, params, sizeof(params));
result->op = GGML_OP_ROPE; result->op = GGML_OP_ROPE;
@ -3950,7 +3960,7 @@ struct ggml_tensor * ggml_rope(
int n_dims, int n_dims,
int mode) { int mode) {
return ggml_rope_impl( return ggml_rope_impl(
ctx, a, b, NULL, n_dims, mode, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, false ctx, a, b, NULL, n_dims, NULL, mode, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, false
); );
} }
@ -3960,7 +3970,7 @@ struct ggml_tensor * ggml_rope_multi(
struct ggml_tensor * b, struct ggml_tensor * b,
struct ggml_tensor * c, struct ggml_tensor * c,
int n_dims, int n_dims,
int sections[4], int sections[GGML_MROPE_SECTIONS],
int mode, int mode,
int n_ctx_orig, int n_ctx_orig,
float freq_base, float freq_base,
@ -3969,36 +3979,31 @@ struct ggml_tensor * ggml_rope_multi(
float attn_factor, float attn_factor,
float beta_fast, float beta_fast,
float beta_slow) { float beta_slow) {
// Multimodal Rotary Position Embedding return ggml_rope_impl(
GGML_ASSERT((mode & 1) == 0 && "mode & 1 == 1 is no longer supported"); ctx, a, b, c, n_dims, sections, mode, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow, false
GGML_ASSERT(ggml_is_vector(b)); );
GGML_ASSERT(b->type == GGML_TYPE_I32);
GGML_ASSERT(a->ne[2] * 4 == b->ne[0]); // mrope expecting 4 position ids per token
if (c) {
GGML_ASSERT(c->type == GGML_TYPE_F32);
GGML_ASSERT(c->ne[0] >= n_dims / 2);
} }
struct ggml_tensor * result = ggml_dup_tensor(ctx, a); struct ggml_tensor * ggml_rope_multi_inplace(
struct ggml_context * ctx,
int32_t params[11 + 4] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig }; struct ggml_tensor * a,
memcpy(params + 5, &freq_base, sizeof(float)); struct ggml_tensor * b,
memcpy(params + 6, &freq_scale, sizeof(float)); struct ggml_tensor * c,
memcpy(params + 7, &ext_factor, sizeof(float)); int n_dims,
memcpy(params + 8, &attn_factor, sizeof(float)); int sections[GGML_MROPE_SECTIONS],
memcpy(params + 9, &beta_fast, sizeof(float)); int mode,
memcpy(params + 10, &beta_slow, sizeof(float)); int n_ctx_orig,
memcpy(&params[11], sections, sizeof(int)*4); float freq_base,
ggml_set_op_params(result, params, sizeof(params)); float freq_scale,
float ext_factor,
result->op = GGML_OP_ROPE; float attn_factor,
result->src[0] = a; float beta_fast,
result->src[1] = b; float beta_slow) {
result->src[2] = c; return ggml_rope_impl(
ctx, a, b, c, n_dims, sections, mode, n_ctx_orig, freq_base, freq_scale,
return result; ext_factor, attn_factor, beta_fast, beta_slow, true
);
} }
struct ggml_tensor * ggml_rope_inplace( struct ggml_tensor * ggml_rope_inplace(
@ -4008,7 +4013,7 @@ struct ggml_tensor * ggml_rope_inplace(
int n_dims, int n_dims,
int mode) { int mode) {
return ggml_rope_impl( return ggml_rope_impl(
ctx, a, b, NULL, n_dims, mode, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, true ctx, a, b, NULL, n_dims, NULL, mode, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, true
); );
} }
@ -4027,7 +4032,7 @@ struct ggml_tensor * ggml_rope_ext(
float beta_fast, float beta_fast,
float beta_slow) { float beta_slow) {
return ggml_rope_impl( return ggml_rope_impl(
ctx, a, b, c, n_dims, mode, n_ctx_orig, freq_base, freq_scale, ctx, a, b, c, n_dims, NULL, mode, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow, false ext_factor, attn_factor, beta_fast, beta_slow, false
); );
} }
@ -4047,7 +4052,7 @@ struct ggml_tensor * ggml_rope_ext_inplace(
float beta_fast, float beta_fast,
float beta_slow) { float beta_slow) {
return ggml_rope_impl( return ggml_rope_impl(
ctx, a, b, c, n_dims, mode, n_ctx_orig, freq_base, freq_scale, ctx, a, b, c, n_dims, NULL, mode, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow, true ext_factor, attn_factor, beta_fast, beta_slow, true
); );
} }
@ -4066,7 +4071,7 @@ struct ggml_tensor * ggml_rope_custom(
float beta_fast, float beta_fast,
float beta_slow) { float beta_slow) {
return ggml_rope_impl( return ggml_rope_impl(
ctx, a, b, NULL, n_dims, mode, n_ctx_orig, freq_base, freq_scale, ctx, a, b, NULL, n_dims, NULL, mode, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow, false ext_factor, attn_factor, beta_fast, beta_slow, false
); );
} }
@ -4085,7 +4090,7 @@ struct ggml_tensor * ggml_rope_custom_inplace(
float beta_fast, float beta_fast,
float beta_slow) { float beta_slow) {
return ggml_rope_impl( return ggml_rope_impl(
ctx, a, b, NULL, n_dims, mode, n_ctx_orig, freq_base, freq_scale, ctx, a, b, NULL, n_dims, NULL, mode, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow, true ext_factor, attn_factor, beta_fast, beta_slow, true
); );
} }
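
ggml_rope_multi is now a thin wrapper over ggml_rope_impl, and the new ggml_rope_multi_inplace mirrors it with inplace set; the M-RoPE checks (4 position ids per token, explicit section split) moved into ggml_rope_impl behind the GGML_ROPE_TYPE_MROPE mode bit. A call-site sketch: ctx, cur, pos, n_rot, n_ctx_orig and the freq/beta values are assumed to exist, GGML_MROPE_SECTIONS is assumed to be 4, and the section split below is a placeholder, not a recommendation:

// with GGML_ROPE_TYPE_MROPE the position tensor must hold 4 ids per token
// (a->ne[2] * 4 == b->ne[0]); sections[] describes the per-dimension-group split
int sections[GGML_MROPE_SECTIONS] = { 16, 24, 24, 0 };   // placeholder values

struct ggml_tensor * rotated = ggml_rope_multi(
        ctx, cur, pos, /*freq factors*/ NULL,
        n_rot, sections, GGML_ROPE_TYPE_MROPE, n_ctx_orig,
        freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);

// the new in-place variant takes the same arguments and writes back into cur
struct ggml_tensor * rotated_inplace = ggml_rope_multi_inplace(
        ctx, cur, pos, NULL,
        n_rot, sections, GGML_ROPE_TYPE_MROPE, n_ctx_orig,
        freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);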
@ -4283,14 +4288,13 @@ struct ggml_tensor * ggml_conv_1d_dw(
int s0, int s0,
int p0, int p0,
int d0) { int d0) {
struct ggml_tensor * new_a = ggml_reshape_4d(ctx, a, a->ne[0], 1, a->ne[1], a->ne[2]);
struct ggml_tensor * new_b = ggml_reshape_4d(ctx, b, b->ne[0], 1, b->ne[1], b->ne[2]); struct ggml_tensor * new_b = ggml_reshape_4d(ctx, b, b->ne[0], 1, b->ne[1], b->ne[2]);
struct ggml_tensor * im2col = ggml_im2col(ctx, new_a, new_b, s0, 0, p0, 0, d0, 0, false, GGML_TYPE_F16); struct ggml_tensor * im2col = ggml_im2col(ctx, a, new_b, s0, 0, p0, 0, d0, 0, false, GGML_TYPE_F16);
struct ggml_tensor * result = ggml_mul_mat(ctx, im2col, a); struct ggml_tensor * result = ggml_mul_mat(ctx, im2col, a);
result = ggml_reshape_3d(ctx, result, b->ne[0], b->ne[1], 1); result = ggml_reshape_3d(ctx, result, result->ne[0], result->ne[2], 1);
return result; return result;
} }
@ -5618,6 +5622,28 @@ struct ggml_tensor * ggml_opt_step_adamw(
return result; return result;
} }
// opt_step_sgd
struct ggml_tensor * ggml_opt_step_sgd(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * grad,
struct ggml_tensor * params) {
GGML_ASSERT(a->flags & GGML_TENSOR_FLAG_PARAM);
GGML_ASSERT(ggml_are_same_shape(a, grad));
GGML_ASSERT(params->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_nelements(params) == 2);
struct ggml_tensor * result = ggml_view_tensor(ctx, a);
result->op = GGML_OP_OPT_STEP_SGD;
result->src[0] = a;
result->src[1] = grad;
result->src[2] = params;
return result;
}
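
Like ggml_opt_step_adamw, the new op is a view of the parameter tensor with the gradient and a 2-float parameter tensor as extra sources; ggml-opt.cpp builds it automatically, but it can also be wired into a graph by hand. A minimal sketch, assuming ctx, gb_opt, weight and grad already exist and weight carries GGML_TENSOR_FLAG_PARAM:

// [alpha, wd], filled from the host before each optimizer pass
struct ggml_tensor * sgd_pars = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 2);
ggml_set_input(sgd_pars);

struct ggml_tensor * step = ggml_opt_step_sgd(ctx, weight, grad, sgd_pars);
ggml_build_forward_expand(gb_opt, step);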
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
struct ggml_hash_set ggml_hash_set_new(size_t size) { struct ggml_hash_set ggml_hash_set_new(size_t size) {

View file

@ -873,6 +873,29 @@ extern "C" {
size_t n_token_capacity, size_t n_token_capacity,
size_t * n_token_count_out); size_t * n_token_count_out);
#define LLAMA_STATE_SEQ_FLAGS_SWA_ONLY 1
typedef uint32_t llama_state_seq_flags;
LLAMA_API size_t llama_state_seq_get_size_ext(
struct llama_context * ctx,
llama_seq_id seq_id,
llama_state_seq_flags flags);
LLAMA_API size_t llama_state_seq_get_data_ext(
struct llama_context * ctx,
uint8_t * dst,
size_t size,
llama_seq_id seq_id,
llama_state_seq_flags flags);
LLAMA_API size_t llama_state_seq_set_data_ext(
struct llama_context * ctx,
const uint8_t * src,
size_t size,
llama_seq_id dest_seq_id,
llama_state_seq_flags flags);
// //
// Decoding // Decoding
// //
@ -1440,6 +1463,8 @@ extern "C" {
ggml_opt_get_optimizer_params get_opt_pars; // callback for calculating optimizer parameters ggml_opt_get_optimizer_params get_opt_pars; // callback for calculating optimizer parameters
void * get_opt_pars_ud; // userdata for calculating optimizer parameters void * get_opt_pars_ud; // userdata for calculating optimizer parameters
enum ggml_opt_optimizer_type optimizer_type;
}; };
LLAMA_API void llama_opt_init(struct llama_context * lctx, struct llama_model * model, struct llama_opt_params lopt_params); LLAMA_API void llama_opt_init(struct llama_context * lctx, struct llama_model * model, struct llama_opt_params lopt_params);
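
The *_ext entry points take a llama_state_seq_flags mask, the old fixed-signature functions now simply forward with flags = 0, and LLAMA_STATE_SEQ_FLAGS_SWA_ONLY restricts the dump to the SWA portion of a sequence's state. A usage sketch, with ctx and seq_id assumed to exist and error handling omitted:

#include <vector>

const llama_state_seq_flags flags = LLAMA_STATE_SEQ_FLAGS_SWA_ONLY;

// size query, then serialize only the SWA part of this sequence's state
const size_t n = llama_state_seq_get_size_ext(ctx, seq_id, flags);
std::vector<uint8_t> buf(n);
llama_state_seq_get_data_ext(ctx, buf.data(), buf.size(), seq_id, flags);

// later, restore it into the same (or another) sequence id
llama_state_seq_set_data_ext(ctx, buf.data(), buf.size(), seq_id, flags);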

View file

@ -476,7 +476,7 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch, bool sequential) { llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch, bool sequential) {
if (sequential && has_cpl) { if (sequential && has_cpl) {
LLAMA_LOG_ERROR("%s: sequential split is not supported when there are coupled sequences in the input batch\n", __func__); LLAMA_LOG_ERROR("%s: sequential split is not supported when there are coupled sequences in the input batch (you may need to use the -kvu flag)\n", __func__);
return {}; return {};
} }

View file

@@ -1657,30 +1657,30 @@ size_t llama_context::state_set_data(const uint8_t * src, size_t size) {
    }
}

size_t llama_context::state_seq_get_size(llama_seq_id seq_id, llama_state_seq_flags flags) {
    llama_io_write_dummy io;
    try {
        return state_seq_write_data(io, seq_id, flags);
    } catch (const std::exception & err) {
        LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what());
        return 0;
    }
}

size_t llama_context::state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size, llama_state_seq_flags flags) {
    llama_io_write_buffer io(dst, size);
    try {
        return state_seq_write_data(io, seq_id, flags);
    } catch (const std::exception & err) {
        LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what());
        return 0;
    }
}

size_t llama_context::state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size, llama_state_seq_flags flags) {
    llama_io_read_buffer io(src, size);
    try {
        return state_seq_read_data(io, seq_id, flags);
    } catch (const std::exception & err) {
        LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what());
        return 0;

@@ -1778,7 +1778,7 @@ size_t llama_context::state_seq_load_file(llama_seq_id seq_id, const char * file
    {
        const size_t state_size = file.size() - file.tell();
        llama_io_read_file io(&file);
        const size_t nread = state_seq_read_data(io, seq_id, 0);
        if (!nread) {
            LLAMA_LOG_ERROR("%s: failed to restore sequence state\n", __func__);
            return 0;

@@ -1802,7 +1802,7 @@ size_t llama_context::state_seq_save_file(llama_seq_id seq_id, const char * file
    // save the context state using stream saving
    llama_io_write_file io(&file);
    state_seq_write_data(io, seq_id, 0);

    const size_t res = file.tell();
    GGML_ASSERT(res == sizeof(uint32_t) * 3 + sizeof(llama_token) * n_token_count + io.n_bytes());

@@ -1971,21 +1971,21 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
    return io.n_bytes();
}

size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
    GGML_UNUSED(seq_id);

    if (memory) {
        memory->state_write(io, seq_id, flags);
    }

    return io.n_bytes();
}

size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
    GGML_UNUSED(seq_id);

    if (memory) {
        memory->state_read(io, seq_id, flags);
    }

    return io.n_bytes();

@@ -2048,7 +2048,7 @@ void llama_context::opt_init(struct llama_model * model, struct llama_opt_params
    opt_params.opt_period      = n_batch / n_ubatch;
    opt_params.get_opt_pars    = lopt_params.get_opt_pars;
    opt_params.get_opt_pars_ud = lopt_params.get_opt_pars_ud;
    opt_params.optimizer       = lopt_params.optimizer_type;

    opt_ctx = ggml_opt_init(opt_params);

    llama_opt_param_filter param_filter = lopt_params.param_filter;

@@ -2801,19 +2801,31 @@ bool llama_state_save_file(llama_context * ctx, const char * path_session, const
}

size_t llama_state_seq_get_size(llama_context * ctx, llama_seq_id seq_id) {
    return llama_state_seq_get_size_ext(ctx, seq_id, 0);
}

size_t llama_state_seq_get_data(llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id) {
    return llama_state_seq_get_data_ext(ctx, dst, size, seq_id, 0);
}

size_t llama_state_seq_set_data(llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id) {
    return llama_state_seq_set_data_ext(ctx, src, size, seq_id, 0);
}

size_t llama_state_seq_get_size_ext(llama_context * ctx, llama_seq_id seq_id, llama_state_seq_flags flags) {
    return ctx->state_seq_get_size(seq_id, flags);
}

size_t llama_state_seq_get_data_ext(llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id, llama_state_seq_flags flags) {
    ctx->synchronize();

    return ctx->state_seq_get_data(seq_id, dst, size, flags);
}

size_t llama_state_seq_set_data_ext(llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id, llama_state_seq_flags flags) {
    ctx->synchronize();

    return ctx->state_seq_set_data(seq_id, src, size, flags);
}

size_t llama_state_seq_save_file(llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) {

View file

@@ -111,9 +111,9 @@ struct llama_context {
    size_t state_get_data(      uint8_t * dst, size_t size);
    size_t state_set_data(const uint8_t * src, size_t size);

    size_t state_seq_get_size(llama_seq_id seq_id, llama_state_seq_flags flags);
    size_t state_seq_get_data(llama_seq_id seq_id,       uint8_t * dst, size_t size, llama_state_seq_flags flags);
    size_t state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size, llama_state_seq_flags flags);

    bool state_load_file(
            const char * filepath,

@@ -152,6 +152,7 @@ struct llama_context {
    void opt_init(struct llama_model * model, struct llama_opt_params lopt_params);

    // TODO: more flexible combinations of logical/physical batch size and context size
    void opt_epoch(
            ggml_opt_dataset_t dataset,
            ggml_opt_result_t result_train,

@@ -212,8 +213,8 @@ private:
    size_t state_write_data(llama_io_write_i & io);
    size_t state_read_data (llama_io_read_i & io);

    size_t state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags);
    size_t state_seq_read_data (llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags);

    //
    // members

View file

@@ -198,14 +198,20 @@ bool llama_kv_cache_unified_iswa::get_can_shift() const {
    return kv_base->get_size() == kv_swa->get_size();
}

void llama_kv_cache_unified_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
    if ((flags & LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) == 0) {
        kv_base->state_write(io, seq_id, flags);
    }

    kv_swa->state_write(io, seq_id, flags);
}

void llama_kv_cache_unified_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
    if ((flags & LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) == 0) {
        kv_base->state_read(io, seq_id, flags);
    }

    kv_swa->state_read(io, seq_id, flags);
}

llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_base() const {
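
A small illustrative check of what the flag changes for an iSWA cache: with LLAMA_STATE_SEQ_FLAGS_SWA_ONLY set, the base (non-SWA) cache is skipped, so the reported per-sequence state size should not exceed the full snapshot. The sketch below rests on that assumption and is not code from this change; the helper name is made up.

#include "llama.h"
#include <cstdio>

// Sketch: compare full vs SWA-only snapshot sizes for one sequence.
// With an iSWA cache, only kv_swa is serialized when the SWA_ONLY flag is set,
// so size_swa <= size_full is the expectation.
static void report_state_sizes(llama_context * ctx, llama_seq_id seq_id) {
    const size_t size_full = llama_state_seq_get_size_ext(ctx, seq_id, 0);
    const size_t size_swa  = llama_state_seq_get_size_ext(ctx, seq_id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);
    printf("full = %zu bytes, swa-only = %zu bytes\n", size_full, size_swa);
}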

View file

@@ -56,8 +56,8 @@ public:
    // state write/load

    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
    void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;

    //
    // llama_kv_cache_unified_iswa specific API

View file

@@ -1828,7 +1828,9 @@ bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
    return false;
}

void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
    GGML_UNUSED(flags);

    io.write(&n_stream, sizeof(n_stream));

    for (uint32_t s = 0; s < n_stream; ++s) {

@@ -1879,7 +1881,9 @@ void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq
    }
}

void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
    GGML_UNUSED(flags);

    GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()));

    uint32_t n_stream_cur;

View file

@@ -136,8 +136,8 @@ public:
    // state write/load

    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
    void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;

    //
    // llama_kv_cache_unified specific API

View file

@@ -165,12 +165,16 @@ llama_pos llama_memory_hybrid::seq_pos_max(llama_seq_id seq_id) const {
    return std::min(mem_attn->seq_pos_max(seq_id), mem_recr->seq_pos_max(seq_id));
}

void llama_memory_hybrid::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
    GGML_UNUSED(flags);

    mem_attn->state_write(io, seq_id);
    mem_recr->state_write(io, seq_id);
}

void llama_memory_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
    GGML_UNUSED(flags);

    mem_attn->state_read(io, seq_id);
    mem_recr->state_read(io, seq_id);
}

View file

@@ -74,8 +74,8 @@ public:
    // state write/load

    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
    void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;

    //
    // llama_memory_hybrid specific API

View file

@@ -680,7 +680,9 @@ size_t llama_memory_recurrent::size_s_bytes() const {
    return size_s_bytes;
}

void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
    GGML_UNUSED(flags);

    std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
    uint32_t cell_count = 0;

@@ -718,7 +720,9 @@ void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq
    state_write_data(io, cell_ranges);
}

void llama_memory_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
    GGML_UNUSED(flags);

    uint32_t cell_count;
    io.read_to(&cell_count, sizeof(cell_count));

View file

@@ -63,8 +63,8 @@ public:
    // state write/load

    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
    void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;

    uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot())
    uint32_t size = 0; // total number of cells, shared across all sequences

View file

@@ -104,8 +104,8 @@ struct llama_memory_i {
    // state write/read
    //
    virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const = 0;
    virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) = 0;
};

using llama_memory_ptr = std::unique_ptr<llama_memory_i>;

Binary file not shown.

View file

@@ -692,6 +692,13 @@ struct completion_token_output {
    }
};

struct swa_checkpoint {
    llama_pos pos_min;
    llama_pos pos_max;

    std::vector<uint8_t> data;
};

struct server_task_result_cmpl_final : server_task_result {
    int index = 0;
@@ -1336,6 +1343,8 @@ struct server_slot {
    std::vector<completion_token_output> generated_token_probs;

    std::vector<swa_checkpoint> swa_checkpoints;

    bool has_next_token = true;
    bool has_new_line   = false;
    bool truncated      = false;
@@ -2015,6 +2024,10 @@ struct server_context {
            params_dft.cache_type_k = params_base.speculative.cache_type_k;
            params_dft.cache_type_v = params_base.speculative.cache_type_v;

            params_dft.cpuparams.n_threads       = params_base.speculative.cpuparams.n_threads;
            params_dft.cpuparams_batch.n_threads = params_base.speculative.cpuparams_batch.n_threads;
            params_dft.tensor_buft_overrides     = params_base.speculative.tensor_buft_overrides;

            llama_init_dft = common_init_from_params(params_dft);

            model_dft = llama_init_dft.model.get();
@@ -3289,6 +3302,8 @@ struct server_context {
                    slot.n_past = 0;
                }

                const auto n_swa = llama_model_n_swa(model);

                if (slot.n_past > 0 && slot.n_past < (int) slot.cache_tokens.size()) {
                    const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
                    if (pos_min == -1) {

@@ -3296,12 +3311,58 @@ struct server_context {
                        GGML_ABORT("pos_min == -1, but n_past > 0 - should not happen: https://github.com/ggml-org/llama.cpp/pull/13833#discussion_r2116181237");
                    }

                    const auto pos_min_thold = std::max(0, slot.n_past - n_swa);

                    if (pos_min > pos_min_thold) {
                        SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min, n_swa);

                        // search for a SWA checkpoint
                        const auto it = std::find_if(
                                slot.swa_checkpoints.rbegin(),
                                slot.swa_checkpoints.rend(),
                                [&](const auto & cur) {
                                    return cur.pos_min <= pos_min_thold;
                                }
                                );

                        bool do_reset = it == slot.swa_checkpoints.rend();

                        if (!do_reset) {
                            // restore the checkpoint
                            const size_t swa_size = it->data.size();
                            const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), swa_size, slot.id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);

                            if (n != swa_size) {
                                SLT_ERR(slot, "failed to restore SWA checkpoint, pos_min = %d, pos_max = %d, size = %.3f MiB\n", it->pos_min, it->pos_max, (float) swa_size / 1024 / 1024);
                                do_reset = true;
                            } else {
                                slot.n_past = std::min(slot.n_past, it->pos_max);
                                SLT_WRN(slot, "SWA checkpoint restore, pos_min = %d, pos_max = %d, size = %.3f MiB\n", it->pos_min, it->pos_max, (float) swa_size / 1024 / 1024);
                            }
                        }

                        if (do_reset) {
                            SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA, see %s)\n",
                                    "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
                            slot.n_past = 0;
                            slot.swa_checkpoints.clear();
                        }
                    }
                }

                if (n_swa > 0) {
                    const auto pos_min_thold = std::max(0, slot.n_past - n_swa);

                    // erase any checkpoints with pos_min > pos_min_thold
                    for (int i = (int) slot.swa_checkpoints.size() - 1; i >= 0; i--) {
                        const auto & cur = slot.swa_checkpoints[i];
                        if (cur.pos_min > pos_min_thold) {
                            slot.swa_checkpoints.erase(slot.swa_checkpoints.begin() + i);

                            SLT_WRN(slot, "SWA checkpoint erase, pos_min = %d, pos_max = %d, size = %.3f MiB\n", cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
                        }
                    }
                }
            }
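
To make the selection rule above easier to follow, the same predicate can be factored into a standalone helper. This is a sketch that reuses the swa_checkpoint struct added earlier in this file; the function name is made up and the code is not part of the change:

#include <algorithm>
#include <vector>

// Sketch: pick the most recent checkpoint whose pos_min does not exceed
// the SWA threshold max(0, n_past - n_swa); return nullptr if none fits.
static const swa_checkpoint * pick_swa_checkpoint(const std::vector<swa_checkpoint> & cps, int n_past, int n_swa) {
    const int pos_min_thold = std::max(0, n_past - n_swa);
    const auto it = std::find_if(cps.rbegin(), cps.rend(),
            [&](const swa_checkpoint & cur) { return cur.pos_min <= pos_min_thold; });
    return it == cps.rend() ? nullptr : &*it;
}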
@@ -3515,6 +3576,39 @@ struct server_context {
                    // prompt evaluated for next-token prediction
                    slot.state = SLOT_STATE_GENERATING;

                    // make a checkpoint with the SWA memory
                    // checkpoints are needed only if we are not using "--swa-full"
                    if (llama_model_n_swa(model) > 0 && !params_base.swa_full && params_base.n_swa_checkpoints > 0) {
                        if (slot.swa_checkpoints.size() >= (size_t) params_base.n_swa_checkpoints) {
                            {
                                const auto & cur = slot.swa_checkpoints.back();

                                SLT_WRN(slot, "SWA checkpoint erase, pos_min = %d, pos_max = %d, size = %.3f MiB\n",
                                        cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
                            }

                            slot.swa_checkpoints.erase(slot.swa_checkpoints.begin());
                        }

                        const size_t swa_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);

                        auto & cur = slot.swa_checkpoints.emplace_back(swa_checkpoint{
                            /*.pos_min = */ llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id),
                            /*.pos_max = */ llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id),
                            /*.data    = */ std::vector<uint8_t>(swa_size),
                        });

                        llama_state_seq_get_data_ext(ctx, cur.data.data(), swa_size, slot.id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);

                        float size_total = 0.0f;
                        for (const auto & checkpoint : slot.swa_checkpoints) {
                            size_total += (float) checkpoint.data.size() / 1024 / 1024;
                        }

                        SLT_WRN(slot, "SWA checkpoint create, pos_min = %d, pos_max = %d, size = %.3f MiB, total = %d/%d (%.3f MiB)\n",
                                cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024, (int) slot.swa_checkpoints.size(), params_base.n_swa_checkpoints, size_total);
                    }
                } else if (slot.state != SLOT_STATE_GENERATING) {
                    continue; // continue loop of slots
                }
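
One practical implication of the block above: each slot can hold up to n_swa_checkpoints snapshots, so the extra memory per slot is roughly the SWA snapshot size times that count. A tiny sketch of that bookkeeping follows; the helper mirrors the size_total loop above and is not part of the change:

#include <vector>

// Sketch: total memory currently held by a slot's SWA checkpoints, in MiB.
static double swa_checkpoints_total_mib(const std::vector<swa_checkpoint> & cps) {
    double total = 0.0;
    for (const auto & cp : cps) {
        total += (double) cp.data.size() / 1024.0 / 1024.0;
    }
    return total;
}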

View file

@@ -130,7 +130,12 @@ export function filterThoughtFromMsgs(messages: APIMessage[]) {
      role: msg.role,
      content:
        msg.role === 'assistant'
          ? contentStr
              .split(
                /<\/think>|<\|start\|>assistant<\|channel\|>final<\|message\|>/
              )
              .at(-1)!
              .trim()
          : contentStr,
    } as APIMessage;
  });