mirror of https://github.com/LostRuins/koboldcpp.git
synced 2025-09-15 11:29:43 +00:00

hope i didnt break anything

commit 7ac0102ed3
45 changed files with 2114 additions and 1090 deletions
.github/workflows/build-riscv-native.yml (new file, vendored, 43 lines)
@@ -0,0 +1,43 @@
name: Build on RISCV Linux Machine by Cloud-V

on:
  workflow_dispatch:
  workflow_call:

jobs:
  bianbu-riscv64-native: # Bianbu 2.2
    runs-on: self-hosted

    steps:
      - name: Install prerequisites
        run: |
          sudo apt-get update || true
          sudo apt-get install -y libatomic1

      - uses: actions/checkout@v4

      - name: Setup Riscv
        run: |
          sudo apt-get update || true
          sudo apt-get install -y --no-install-recommends \
            build-essential \
            gcc-14-riscv64-linux-gnu \
            g++-14-riscv64-linux-gnu \
            cmake

      - name: Build
        run: |
          cmake -B build -DLLAMA_CURL=OFF \
            -DCMAKE_BUILD_TYPE=Release \
            -DGGML_OPENMP=OFF \
            -DLLAMA_BUILD_EXAMPLES=ON \
            -DLLAMA_BUILD_TOOLS=ON \
            -DLLAMA_BUILD_TESTS=OFF \
            -DCMAKE_SYSTEM_NAME=Linux \
            -DCMAKE_SYSTEM_PROCESSOR=riscv64 \
            -DCMAKE_C_COMPILER=riscv64-linux-gnu-gcc-14 \
            -DCMAKE_CXX_COMPILER=riscv64-linux-gnu-g++-14 \
            -DCMAKE_POSITION_INDEPENDENT_CODE=ON \
            -DCMAKE_FIND_ROOT_PATH=/usr/lib/riscv64-linux-gnu \
            -DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \
            -DCMAKE_FIND_ROOT_PATH_MODE_LIBRARY=ONLY \
            -DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH

          cmake --build build --config Release -j $(nproc)
.github/workflows/copilot-setup-steps.yml (new file, vendored, 53 lines)
@@ -0,0 +1,53 @@
name: "Copilot Setup Steps"

# Automatically run the setup steps when they are changed to allow for easy validation, and
# allow manual testing through the repository's "Actions" tab
on:
  workflow_dispatch:
  push:
    paths:
      - .github/workflows/copilot-setup-steps.yml
  pull_request:
    paths:
      - .github/workflows/copilot-setup-steps.yml

jobs:
  # The job MUST be called `copilot-setup-steps` or it will not be picked up by Copilot.
  copilot-setup-steps:
    runs-on: ubuntu-latest

    # Set the permissions to the lowest permissions possible needed for your steps.
    # Copilot will be given its own token for its operations.
    permissions:
      # If you want to clone the repository as part of your setup steps, for example to install dependencies, you'll need the `contents: read` permission. If you don't clone the repository in your setup steps, Copilot will do this for you automatically after the steps complete.
      contents: read

    # You can define any steps you want, and they will run before the agent starts.
    # If you do not check out your code, Copilot will do this for you.
    steps:
      - name: Checkout code
        uses: actions/checkout@v4

      - name: ccache
        uses: hendrikmuhs/ccache-action@v1.2.16
        with:
          key: copilot-setup-steps
          evict-old-files: 1d

      - name: Dependencies
        id: depends
        run: |
          sudo apt-get update
          sudo apt-get install build-essential libcurl4-openssl-dev

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.11'

      - name: Install Python dependencies
        run: |
          python3 -m venv .venv
          .venv/bin/activate
          pip install -r requirements/requirements-all.txt -r tools/server/tests/requirements.txt
          pip install flake8 pyright
common/arg.cpp (158 lines changed)
@@ -751,6 +751,39 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string &
 // utils
 //
+
+// Helper function to parse tensor buffer override strings
+static void parse_tensor_buffer_overrides(const std::string & value, std::vector<llama_model_tensor_buft_override> & overrides) {
+    std::map<std::string, ggml_backend_buffer_type_t> buft_list;
+    for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
+        auto * dev = ggml_backend_dev_get(i);
+        auto * buft = ggml_backend_dev_buffer_type(dev);
+        if (buft) {
+            buft_list[ggml_backend_buft_name(buft)] = buft;
+        }
+    }
+
+    for (const auto & override : string_split<std::string>(value, ',')) {
+        std::string::size_type pos = override.find('=');
+        if (pos == std::string::npos) {
+            throw std::invalid_argument("invalid value");
+        }
+        std::string tensor_name = override.substr(0, pos);
+        std::string buffer_type = override.substr(pos + 1);
+
+        if (buft_list.find(buffer_type) == buft_list.end()) {
+            printf("Available buffer types:\n");
+            for (const auto & it : buft_list) {
+                printf("  %s\n", ggml_backend_buft_name(it.second));
+            }
+            throw std::invalid_argument("unknown buffer type");
+        }
+        // keep strings alive and avoid leaking memory by storing them in a static vector
+        static std::list<std::string> buft_overrides;
+        buft_overrides.push_back(tensor_name);
+        overrides.push_back({buft_overrides.back().c_str(), buft_list.at(buffer_type)});
+    }
+}
+
 struct handle_model_result {
     bool found_mmproj = false;
     common_params_model mmproj;
@@ -995,6 +1028,10 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
         params.tensor_buft_overrides.push_back({nullptr, nullptr});
     }
+
+    if (!params.speculative.tensor_buft_overrides.empty()) {
+        params.speculative.tensor_buft_overrides.push_back({nullptr, nullptr});
+    }
+
     if (!params.chat_template.empty() && !common_chat_verify_template(params.chat_template, params.use_jinja)) {
         throw std::runtime_error(string_format(
             "error: the supplied chat template is not supported: %s%s\n",
@@ -1203,6 +1240,7 @@ bool common_params_parse(int argc, char ** argv, common_params & params, llama_e
             common_params_print_completion(ctx_arg);
             exit(0);
         }
+        params.lr.init();
     } catch (const std::invalid_argument & ex) {
         fprintf(stderr, "%s\n", ex.what());
         ctx_arg.params = params_org;
@@ -1471,6 +1509,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.swa_full = true;
         }
     ).set_env("LLAMA_ARG_SWA_FULL"));
+    add_opt(common_arg(
+        {"--swa-checkpoints"}, "N",
+        string_format("max number of SWA checkpoints per slot to create (default: %d)\n"
+            "[(more info)](https://github.com/ggml-org/llama.cpp/pull/15293)", params.n_swa_checkpoints),
+        [](common_params & params, int value) {
+            params.n_swa_checkpoints = value;
+        }
+    ).set_env("LLAMA_ARG_SWA_CHECKPOINTS").set_examples({LLAMA_EXAMPLE_SERVER}));
     add_opt(common_arg(
         {"--kv-unified", "-kvu"},
         string_format("use single unified KV buffer for the KV cache of all sequences (default: %s)\n"
@@ -2351,40 +2397,15 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
     add_opt(common_arg(
         {"--override-tensor", "-ot"}, "<tensor name pattern>=<buffer type>,...",
         "override tensor buffer type", [](common_params & params, const std::string & value) {
-            /* static */ std::map<std::string, ggml_backend_buffer_type_t> buft_list;
-            if (buft_list.empty()) {
-                // enumerate all the devices and add their buffer types to the list
-                for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
-                    auto * dev = ggml_backend_dev_get(i);
-                    auto * buft = ggml_backend_dev_buffer_type(dev);
-                    if (buft) {
-                        buft_list[ggml_backend_buft_name(buft)] = buft;
-                    }
-                }
-            }
-
-            for (const auto & override : string_split<std::string>(value, ',')) {
-                std::string::size_type pos = override.find('=');
-                if (pos == std::string::npos) {
-                    throw std::invalid_argument("invalid value");
-                }
-                std::string tensor_name = override.substr(0, pos);
-                std::string buffer_type = override.substr(pos + 1);
-
-                if (buft_list.find(buffer_type) == buft_list.end()) {
-                    printf("Available buffer types:\n");
-                    for (const auto & it : buft_list) {
-                        printf("  %s\n", ggml_backend_buft_name(it.second));
-                    }
-                    throw std::invalid_argument("unknown buffer type");
-                }
-                // keep strings alive and avoid leaking memory by storing them in a static vector
-                static std::list<std::string> buft_overrides;
-                buft_overrides.push_back(tensor_name);
-                params.tensor_buft_overrides.push_back({buft_overrides.back().c_str(), buft_list.at(buffer_type)});
-            }
+            parse_tensor_buffer_overrides(value, params.tensor_buft_overrides);
         }
     ));
+    add_opt(common_arg(
+        {"--override-tensor-draft", "-otd"}, "<tensor name pattern>=<buffer type>,...",
+        "override tensor buffer type for draft model", [](common_params & params, const std::string & value) {
+            parse_tensor_buffer_overrides(value, params.speculative.tensor_buft_overrides);
+        }
+    ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}));
     add_opt(common_arg(
         {"--cpu-moe", "-cmoe"},
         "keep all Mixture of Experts (MoE) weights in the CPU",
@@ -2407,6 +2428,27 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             }
         }
     ).set_env("LLAMA_ARG_N_CPU_MOE"));
+    add_opt(common_arg(
+        {"--cpu-moe-draft", "-cmoed"},
+        "keep all Mixture of Experts (MoE) weights in the CPU for the draft model",
+        [](common_params & params) {
+            params.speculative.tensor_buft_overrides.push_back({"\\.ffn_(up|down|gate)_exps", ggml_backend_cpu_buffer_type()});
+        }
+    ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CPU_MOE_DRAFT"));
+    add_opt(common_arg(
+        {"--n-cpu-moe-draft", "-ncmoed"}, "N",
+        "keep the Mixture of Experts (MoE) weights of the first N layers in the CPU for the draft model",
+        [](common_params & params, int value) {
+            if (value < 0) {
+                throw std::invalid_argument("invalid value");
+            }
+            for (int i = 0; i < value; ++i) {
+                static std::list<std::string> buft_overrides_draft;
+                buft_overrides_draft.push_back(string_format("blk\\.%d\\.ffn_(up|down|gate)_exps", i));
+                params.speculative.tensor_buft_overrides.push_back({buft_overrides_draft.back().c_str(), ggml_backend_cpu_buffer_type()});
+            }
+        }
+    ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_N_CPU_MOE_DRAFT"));
     add_opt(common_arg(
         {"-ngl", "--gpu-layers", "--n-gpu-layers"}, "N",
         "number of layers to store in VRAM",
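As an illustration of the loop above (not part of the diff; the value 2 is a hypothetical example), --n-cpu-moe-draft expands into the same pattern=buffer-type overrides that --override-tensor and the new --override-tensor-draft accept:

// Effect of "--n-cpu-moe-draft 2": one CPU-buffer override per draft-model layer index
//   { "blk\\.0\\.ffn_(up|down|gate)_exps", ggml_backend_cpu_buffer_type() }
//   { "blk\\.1\\.ffn_(up|down|gate)_exps", ggml_backend_cpu_buffer_type() }
// i.e. the expert FFN tensors of the first two draft-model layers stay in host memory.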
@@ -2657,7 +2699,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         [](common_params & params, const std::string & value) {
             params.out_file = value;
         }
-    ).set_examples({LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_CVECTOR_GENERATOR, LLAMA_EXAMPLE_EXPORT_LORA, LLAMA_EXAMPLE_TTS}));
+    ).set_examples({LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_CVECTOR_GENERATOR, LLAMA_EXAMPLE_EXPORT_LORA, LLAMA_EXAMPLE_TTS, LLAMA_EXAMPLE_FINETUNE}));
     add_opt(common_arg(
         {"-ofreq", "--output-frequency"}, "N",
         string_format("output the imatrix every N iterations (default: %d)", params.n_out_freq),
@@ -3132,7 +3174,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
                 params.speculative.cpuparams.n_threads = std::thread::hardware_concurrency();
             }
         }
-    ).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
+    ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}));
     add_opt(common_arg(
         {"-tbd", "--threads-batch-draft"}, "N",
         "number of threads to use during batch and prompt processing (default: same as --threads-draft)",
@@ -3142,7 +3184,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
                 params.speculative.cpuparams_batch.n_threads = std::thread::hardware_concurrency();
             }
         }
-    ).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
+    ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}));
     add_opt(common_arg(
         {"-Cd", "--cpu-mask-draft"}, "M",
         "Draft model CPU affinity mask. Complements cpu-range-draft (default: same as --cpu-mask)",
@@ -3535,5 +3577,51 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
     ).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));

+    add_opt(
+        common_arg({ "-lr", "--learning-rate" }, "ALPHA",
+            string_format(
+                "adamw or sgd optimizer alpha (default: %.2g); note: sgd alpha recommended ~10x (no momentum)",
+                (double) params.lr.lr0),
+            [](common_params & params, const std::string & value) { params.lr.lr0 = std::stof(value); })
+            .set_examples({ LLAMA_EXAMPLE_FINETUNE }));
+    add_opt(
+        common_arg({ "-lr-min", "--learning-rate-min" }, "ALPHA",
+            string_format(
+                "(if >0) final learning rate after decay (if -decay-epochs is set, default=%.2g)",
+                (double) params.lr.lr_min),
+            [](common_params & params, const std::string & value) { params.lr.lr_min = std::stof(value); })
+            .set_examples({ LLAMA_EXAMPLE_FINETUNE }));
+    add_opt(
+        common_arg({ "-decay-epochs", "--learning-rate-decay-epochs" }, "ALPHA",
+            string_format(
+                "(if >0) decay learning rate to -lr-min after this many epochs (exponential decay, default=%.2g)",
+                (double) params.lr.decay_epochs),
+            [](common_params & params, const std::string & value) { params.lr.decay_epochs = std::stof(value); })
+            .set_examples({ LLAMA_EXAMPLE_FINETUNE }));
+    add_opt(common_arg(
+        { "-wd", "--weight-decay" }, "WD",
+        string_format(
+            "adamw or sgd optimizer weight decay (0 is off; recommend very small e.g. 1e-9) (default: %.2g).",
+            (double) params.lr.wd),
+        [](common_params & params, const std::string & value) { params.lr.wd = std::stof(value); })
+        .set_examples({ LLAMA_EXAMPLE_FINETUNE }));
+    add_opt(common_arg({ "-val-split", "--val-split" }, "FRACTION",
+        string_format("fraction of data to use as validation set for training (default: %.2g).",
+            (double) params.val_split),
+        [](common_params & params, const std::string & value) { params.val_split = std::stof(value); })
+        .set_examples({ LLAMA_EXAMPLE_FINETUNE }));
+    add_opt(common_arg({ "-epochs", "--epochs" }, "N",
+        string_format("optimizer max # of epochs (default: %d)", params.lr.epochs),
+        [](common_params & params, int epochs) { params.lr.epochs = epochs; })
+        .set_examples({ LLAMA_EXAMPLE_FINETUNE }));
+    add_opt(common_arg({ "-opt", "--optimizer" }, "sgd|adamw", "adamw or sgd",
+        [](common_params & params, const std::string & name) {
+            params.optimizer = common_opt_get_optimizer(name.c_str());
+            if (params.optimizer == GGML_OPT_OPTIMIZER_TYPE_COUNT) {
+                throw std::invalid_argument("invalid --optimizer, valid options: adamw, sgd");
+            }
+        })
+        .set_examples({ LLAMA_EXAMPLE_FINETUNE }));
+
     return ctx_arg;
 }
@@ -49,6 +49,7 @@
 #endif
 #include <locale>
 #include <windows.h>
+#include <string.h>
 #include <fcntl.h>
 #include <io.h>
 #else
@@ -1573,3 +1574,56 @@ ggml_opt_dataset_t common_opt_dataset_init(struct llama_context * ctx, const std

     return result;
 }
+
+ggml_opt_optimizer_params common_opt_lr_pars(void * userdata) {
+    ggml_opt_optimizer_params result = ggml_opt_get_default_optimizer_params(nullptr);
+    const lr_opt & d = *(lr_opt *) userdata;
+    result.adamw.alpha = result.sgd.alpha = d.get_lr(d.epoch);
+    result.sgd.wd = result.adamw.wd = d.wd;
+    return result;
+}
+
+// TODO make all command line args case-insensitive
+static inline bool eq_case_insensitive(char const* a, char const* b) {
+    return !
+#if defined(_MSC_VER)
+        _stricmp
+#else
+        strcasecmp
+#endif // defined(_MSC_VER)
+        (a, b);
+}
+
+enum ggml_opt_optimizer_type common_opt_get_optimizer(const char * n) {
+    if (eq_case_insensitive("adamw", n)) {
+        return GGML_OPT_OPTIMIZER_TYPE_ADAMW;
+    }
+    if (eq_case_insensitive("sgd", n)) {
+        return GGML_OPT_OPTIMIZER_TYPE_SGD;
+    }
+    return GGML_OPT_OPTIMIZER_TYPE_COUNT;
+}
+
+// TODO simplify to use just log and exp
+static float const k_log_2 = std::log(2.f);
+
+void lr_opt::init() {
+    if (lr_min > 0 && lr_min < lr0) {
+        float nhalf = std::log(lr0 / lr_min) / k_log_2;
+        float e = epochs;
+        if (decay_epochs > 0 && decay_epochs < e) {
+            e = decay_epochs;
+        } else {
+            decay_epochs = e;
+        }
+        scale_epoch = nhalf / e;
+    }
+}
+
+float lr_opt::get_lr(float epoch) const {
+    float r = lr_min <= 0 ? lr0 :
+        epoch >= decay_epochs ? lr_min :
+        lr0 * std::pow(0.5f, epoch * scale_epoch);
+    LOG_INF("epoch %.2g lr=%.2g\n", epoch, r);
+    return r;
+}
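Read together, lr_opt::init and lr_opt::get_lr above implement an exponential halving schedule; as a restatement of the code (not text from the diff), for lr_min > 0 and lr_min < lr0:

scale\_epoch = \log_2(lr_0 / lr_{\min}) / \min(\mathrm{decay\_epochs},\ \mathrm{epochs})

lr(e) = \begin{cases} lr_0 & \text{if } lr_{\min} \le 0 \\ lr_{\min} & \text{if } e \ge \mathrm{decay\_epochs} \\ lr_0 \cdot 0.5^{\,e \cdot \mathrm{scale\_epoch}} & \text{otherwise} \end{cases}

so the rate decays from lr_0 to lr_min over decay_epochs epochs and is clamped at lr_min afterwards.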
@@ -2,14 +2,17 @@

 #pragma once

-#include "llama-cpp.h"

 #include <set>
+#include <sstream>
 #include <string>
 #include <string_view>
 #include <vector>
 #include <map>
 #include <sstream>
+#include <cmath>
+
+#include "ggml-opt.h"
+#include "llama-cpp.h"

 #ifdef _WIN32
 #define DIRECTORY_SEPARATOR '\\'
@@ -78,6 +81,7 @@ enum llama_example {
     LLAMA_EXAMPLE_PARALLEL,
     LLAMA_EXAMPLE_TTS,
     LLAMA_EXAMPLE_DIFFUSION,
+    LLAMA_EXAMPLE_FINETUNE,

     LLAMA_EXAMPLE_COUNT,
 };
@@ -198,6 +202,7 @@ struct common_params_speculative {
     float p_split = 0.1f;  // speculative decoding split probability
     float p_min   = 0.75f; // minimum speculative decoding probability (greedy)
     std::vector<std::pair<std::string, std::string>> replacements; // main to speculative model replacements
+    std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;

     ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K
     ggml_type cache_type_v = GGML_TYPE_F16; // KV cache data type for the V
@@ -238,6 +243,25 @@ enum common_reasoning_format {
     COMMON_REASONING_FORMAT_GRANITE, // Extract thinking tag contents and return as `message.reasoning_content`, including in streaming deltas.
 };

+
+struct lr_opt {
+    float lr0 = 1e-5;        // learning rate at first epoch
+    float lr_min = -1;
+    float decay_epochs = -1; // if >0, the learning rate starts at lr0 and decays to lr_min after this many epochs
+    float scale_epoch = 0;
+    float wd = 0;
+    unsigned epochs = 2;
+
+    unsigned epoch; // set by optimizer outer (epochs) loop
+    // learning rate decay - constant LR per epoch only for now
+    float get_lr(float e) const;
+    float get_lr() const { return get_lr(epoch); }
+    // must call after arg parse, before get_lr
+    void init();
+};
+
+struct ggml_opt_optimizer_params common_opt_lr_pars(void * userdata);
+
 struct common_params {
     int32_t n_predict = -1;   // new tokens to predict
     int32_t n_ctx     = 4096; // context size
@@ -372,6 +396,11 @@ struct common_params {
     bool no_mmproj = false;         // explicitly disable multimodal model
     std::vector<std::string> image; // path to image file(s)

+    // finetune
+    struct lr_opt lr;
+    enum ggml_opt_optimizer_type optimizer = GGML_OPT_OPTIMIZER_TYPE_ADAMW;
+    float val_split = 0.05f; // fraction of the data used for the validation set
+
     // embedding
     bool embedding = false;     // get only sentence embedding
     int32_t embd_normalize = 2; // normalisation for embeddings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)
@@ -385,6 +414,7 @@ struct common_params {
     int32_t timeout_write = timeout_read; // http write timeout in seconds
     int32_t n_threads_http = -1;          // number of threads to process HTTP requests (TODO: support threadpool)
     int32_t n_cache_reuse = 0;            // min chunk size to reuse from the cache via KV shifting
+    int32_t n_swa_checkpoints = 3;        // max number of SWA checkpoints per slot

     std::string hostname    = "127.0.0.1";
     std::string public_path = ""; // NOLINT
@@ -699,3 +729,6 @@ const char * const LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count";
 //

 ggml_opt_dataset_t common_opt_dataset_init(struct llama_context * ctx, const std::vector<llama_token> & tokens, int64_t stride);
+
+// "adamw" or "sgd" (case insensitive)
+enum ggml_opt_optimizer_type common_opt_get_optimizer(const char *);
@@ -74,16 +74,26 @@ extern "C" {
         GGML_OPT_BUILD_TYPE_OPT   = 30,
     };

+    enum ggml_opt_optimizer_type {
+        GGML_OPT_OPTIMIZER_TYPE_ADAMW,
+        GGML_OPT_OPTIMIZER_TYPE_SGD,
+
+        GGML_OPT_OPTIMIZER_TYPE_COUNT
+    };
+
     // parameters that control which optimizer is used and how said optimizer tries to find the minimal loss
     struct ggml_opt_optimizer_params {
-        // AdamW optimizer parameters
         struct {
             float alpha; // learning rate
-            float beta1;
-            float beta2;
+            float beta1; // first AdamW momentum
+            float beta2; // second AdamW momentum
             float eps;   // epsilon for numerical stability
-            float wd;    // weight decay for AdamW, use 0.0f to disable
+            float wd;    // weight decay - 0.0f to disable
         } adamw;
+        struct {
+            float alpha; // learning rate
+            float wd;    // weight decay
+        } sgd;
     };

     // callback to calculate optimizer parameters prior to a backward pass
@@ -114,6 +124,9 @@ extern "C" {

         ggml_opt_get_optimizer_params get_opt_pars; // callback for calculating optimizer parameters
         void * get_opt_pars_ud;                     // userdata for calculating optimizer parameters
+
+        // only GGML_OPT_OPTIMIZER_TYPE_ADAMW needs m, v momenta per parameter tensor
+        enum ggml_opt_optimizer_type optimizer;
     };

     // get parameters for an optimization context with defaults set where possible
@@ -142,6 +155,10 @@ extern "C" {
     // get the gradient accumulator for a node from the forward graph
     GGML_API struct ggml_tensor * ggml_opt_grad_acc(ggml_opt_context_t opt_ctx, struct ggml_tensor * node);

+    GGML_API enum ggml_opt_optimizer_type ggml_opt_context_optimizer_type(ggml_opt_context_t); //TODO consistent naming scheme
+
+    GGML_API const char * ggml_opt_optimizer_name(enum ggml_opt_optimizer_type);
+
     // ====== Optimization Result ======

     GGML_API ggml_opt_result_t ggml_opt_result_init(void);
@@ -226,12 +243,14 @@ extern "C" {
             struct ggml_tensor            * outputs,        // output tensor, must have shape [ne_label, ndata_batch] if labels are used
             ggml_opt_dataset_t              dataset,        // dataset with data and optionally also labels
             enum ggml_opt_loss_type         loss_type,      // loss to minimize
+            enum ggml_opt_optimizer_type    optimizer,      // sgd or adamw
             ggml_opt_get_optimizer_params   get_opt_pars,   // callback to get optimizer params, userdata is pointer to epoch (of type int64_t)
             int64_t                         nepoch,         // how many times the dataset should be iterated over
             int64_t                         nbatch_logical, // datapoints optimizer step, must be a multiple of ndata_batch in inputs/outputs
             float                           val_split,      // fraction of the dataset to use for validation, must be in [0.0f, 1.0f)
             bool                            silent);        // whether or not info prints to stderr should be suppressed

 #ifdef __cplusplus
 }
 #endif
@@ -241,6 +241,8 @@
 #define GGML_ROPE_TYPE_MROPE  8
 #define GGML_ROPE_TYPE_VISION 24

+#define GGML_MROPE_SECTIONS   4
+
 #define GGML_UNUSED(x) (void)(x)

 #define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))
@@ -546,6 +548,7 @@ extern "C" {
         GGML_OP_CROSS_ENTROPY_LOSS,
         GGML_OP_CROSS_ENTROPY_LOSS_BACK,
         GGML_OP_OPT_STEP_ADAMW,
+        GGML_OP_OPT_STEP_SGD,

         GGML_OP_GLU,
@@ -1673,7 +1676,7 @@ extern "C" {
             struct ggml_tensor  * b,
             struct ggml_tensor  * c,
             int                   n_dims,
-            int                   sections[4],
+            int                   sections[GGML_MROPE_SECTIONS],
             int                   mode,
             int                   n_ctx_orig,
             float                 freq_base,
@@ -1699,6 +1702,22 @@ extern "C" {
             float                 beta_fast,
             float                 beta_slow);

+    GGML_API struct ggml_tensor * ggml_rope_multi_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            struct ggml_tensor  * c,
+            int                   n_dims,
+            int                   sections[GGML_MROPE_SECTIONS],
+            int                   mode,
+            int                   n_ctx_orig,
+            float                 freq_base,
+            float                 freq_scale,
+            float                 ext_factor,
+            float                 attn_factor,
+            float                 beta_fast,
+            float                 beta_slow);
+
     GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_rope_custom(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
@@ -2306,7 +2325,14 @@ extern "C" {
             struct ggml_tensor  * grad,
             struct ggml_tensor  * m,
             struct ggml_tensor  * v,
-            struct ggml_tensor  * adamw_params); // parameters such a the learning rate
+            struct ggml_tensor  * adamw_params); // parameters such as the learning rate
+
+    // stochastic gradient descent step (with weight decay)
+    GGML_API struct ggml_tensor * ggml_opt_step_sgd(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * grad,
+            struct ggml_tensor  * sgd_params); // alpha, weight decay

     //
     // automatic differentiation
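A minimal sketch of how the new op is meant to be driven, inferred from the declaration above and the CPU kernel added later in this commit (ctx, weights and grad are hypothetical placeholders, not names from the diff):

// sgd_params packs {alpha, wd} as two F32 values, matching the "alpha, weight decay" comment above
struct ggml_tensor * sgd_params = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 2);
struct ggml_tensor * stepped    = ggml_opt_step_sgd(ctx, weights, grad, sgd_params);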
@@ -40,18 +40,22 @@
 #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
 #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
+#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
 #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
 #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
+#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
 #elif defined(__aarch64__) || defined(__arm__) || defined(_M_ARM) || defined(_M_ARM64)
 // repack.cpp
 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
 #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
+#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
 #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
 #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
+#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
 #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
 #elif defined(__x86_64__) || defined(__i386__) || defined(_M_IX86) || defined(_M_X64)
 // repack.cpp
@@ -80,12 +84,14 @@
 #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
 #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
+#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
 #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
 #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
+#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
 #elif defined(__loongarch64)
 // quants.c
 #define quantize_row_q8_K_generic quantize_row_q8_K
@@ -103,12 +109,14 @@
 #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
 #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
+#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
 #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
 #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
+#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
 #elif defined(__riscv)
 // quants.c
 #define quantize_row_q8_K_generic quantize_row_q8_K
@@ -133,11 +141,13 @@
 #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
 #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
+#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
 #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
 #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
+#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
 #elif defined(__s390x__)
 // quants.c
 #define quantize_row_q8_K_generic quantize_row_q8_K
@@ -164,12 +174,14 @@
 #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
 #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
+#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
 #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
 #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
+#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
 #elif defined(__wasm__)
 // quants.c
 #define ggml_vec_dot_q4_1_q8_1_generic ggml_vec_dot_q4_1_q8_1
@@ -195,10 +207,12 @@
 #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
 #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
+#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
 #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
 #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
+#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
 #endif
File diff suppressed because it is too large
@@ -2036,6 +2036,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
                 ggml_compute_forward_opt_step_adamw(params, tensor);
             }
             break;
+        case GGML_OP_OPT_STEP_SGD:
+            {
+                ggml_compute_forward_opt_step_sgd(params, tensor);
+            }
+            break;
         case GGML_OP_NONE:
             {
                 // nop
@@ -2339,6 +2344,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_CROSS_ENTROPY_LOSS:
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
         case GGML_OP_OPT_STEP_ADAMW:
+        case GGML_OP_OPT_STEP_SGD:
             {
                 n_tasks = n_threads;
             } break;
@@ -10330,6 +10330,7 @@ static void ggml_compute_forward_opt_step_adamw_f32(
     const int ir1 = MIN(ir0 + dr, nr);

     const float * adamw_params_ptr = ggml_get_data_f32(adamw_params);

     const float alpha = adamw_params_ptr[0];
     const float beta1 = adamw_params_ptr[1];
     const float beta2 = adamw_params_ptr[2];
|
||||||
const float wd = adamw_params_ptr[4];
|
const float wd = adamw_params_ptr[4];
|
||||||
const float beta1h = adamw_params_ptr[5];
|
const float beta1h = adamw_params_ptr[5];
|
||||||
const float beta2h = adamw_params_ptr[6];
|
const float beta2h = adamw_params_ptr[6];
|
||||||
|
const float keep = 1.f - alpha * wd;
|
||||||
for (int ir = ir0; ir < ir1; ++ir) {
|
for (int ir = ir0; ir < ir1; ++ir) {
|
||||||
const int64_t i03 = ir/(ne02*ne01);
|
const int64_t i03 = ir/(ne02*ne01);
|
||||||
const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
|
const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
|
||||||
|
@ -10360,7 +10361,7 @@ static void ggml_compute_forward_opt_step_adamw_f32(
|
||||||
// The weight decay is applied independently of the Adam momenta m and v.
|
// The weight decay is applied independently of the Adam momenta m and v.
|
||||||
// This is NOT equivalent to l2 regularization that adds w[i00]*w[i00] to the loss.
|
// This is NOT equivalent to l2 regularization that adds w[i00]*w[i00] to the loss.
|
||||||
// See: https://arxiv.org/pdf/1711.05101v3.pdf
|
// See: https://arxiv.org/pdf/1711.05101v3.pdf
|
||||||
w[i00] = w[i00]*(1.0f - alpha*wd) - alpha*mh/vh;
|
w[i00] = w[i00] * keep - alpha * mh / vh;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
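Hoisting keep out of the inner loop does not change the update; per element this is still AdamW's decoupled weight decay from the paper linked above:

w \leftarrow (1 - \alpha \cdot wd)\, w - \alpha \cdot mh / vh, \qquad \mathrm{keep} = 1 - \alpha \cdot wd

where mh and vh are the momentum terms computed earlier in the function (outside this hunk).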
@@ -10382,3 +10383,63 @@ void ggml_compute_forward_opt_step_adamw(
             }
     }
 }
+
+static void ggml_compute_forward_opt_step_sgd_f32(const ggml_compute_params * params, ggml_tensor * dst) {
+    const ggml_tensor * src0       = dst->src[0];
+    const ggml_tensor * src0_grad  = dst->src[1];
+    const ggml_tensor * sgd_params = dst->src[2];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
+    GGML_ASSERT(ggml_nelements(sgd_params) == 2);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr = ggml_nrows(src0);
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    // rows per thread
+    const int dr = (nr + nth - 1) / nth;
+
+    // row range for this thread
+    const int ir0 = dr * ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    // using adamw param subset we care about - alpha, wd - could have a separate struct
+    const float * sgd_params_ptr = ggml_get_data_f32(sgd_params);
+    const float   alpha          = sgd_params_ptr[0];
+    const float   keep           = 1.f - alpha * sgd_params_ptr[1];
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        const int64_t i03 = ir / (ne02 * ne01);
+        const int64_t i02 = (ir - i03 * ne02 * ne01) / ne01;
+        const int64_t i01 = (ir - i03 * ne02 * ne01 - i02 * ne01);
+
+        const size_t offset = i03 * nb03 + i02 * nb02 + i01 * nb01;
+
+        float       * w = (float       *) ((char       *) src0->data      + offset); // weight
+        const float * g = (const float *) ((const char *) src0_grad->data + offset); // grad
+
+        for (int i00 = 0; i00 < ne00; ++i00) {
+            w[i00] = w[i00] * keep - alpha * g[i00];
+        }
+    }
+}
+
+void ggml_compute_forward_opt_step_sgd(const ggml_compute_params * params, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_opt_step_sgd_f32(params, dst);
+            }
+            break;
+        default:
+            {
+                GGML_ABORT("fatal error - sgd is F32 only");
+            }
+    }
+}
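As a restatement of the new kernel (not text from the diff): with sgd_params = {alpha, wd}, every element is updated as

w \leftarrow (1 - \alpha \cdot wd)\, w - \alpha \cdot g

which is plain gradient descent for wd = 0 and uses the same decoupled weight-decay convention as the AdamW path.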
@@ -107,7 +107,7 @@ void ggml_compute_forward_cross_entropy_loss(const struct ggml_compute_params *
 void ggml_compute_forward_cross_entropy_loss_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_opt_step_adamw(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_mul_mat(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_opt_step_sgd(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 #ifdef __cplusplus
 }
 #endif
@@ -206,6 +206,7 @@ void ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs,
     const int ncols_interleaved = 4;
     const int blocklen = 4;

+    assert(nr == 1);
     assert(n % qk == 0);
     assert(nc % ncols_interleaved == 0);

@@ -307,7 +308,6 @@ void ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs,
     UNUSED(ncols_interleaved);
     UNUSED(blocklen);

-    {
     float sumf[8];
     int sumi;

@@ -332,7 +332,6 @@ void ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs,
         for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
     }
 }
-}

 void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
     const int qk = QK_K;
@@ -494,20 +493,13 @@ void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs
     const int ncols_interleaved = 4;
     const int blocklen = 4;

+    assert(nr == 1);
     assert(n % qk == 0);
     assert(nc % ncols_interleaved == 0);

-    UNUSED(s);
     UNUSED(bs);
-    UNUSED(vx);
-    UNUSED(vy);
     UNUSED(nr);
-    UNUSED(nc);
-    UNUSED(nb);
-    UNUSED(ncols_interleaved);
-    UNUSED(blocklen);

-    {
     float sumf[4];
     int sumi;

@@ -532,6 +524,43 @@ void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs
         for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
     }
 }
+
+void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 8;
+    const int blocklen = 8;
+
+    assert(nr == 1);
+    assert(n % qk == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    UNUSED(bs);
+    UNUSED(nr);
+
+    float sumf[8];
+    int sumi;
+
+    const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+    for (int x = 0; x < nc / ncols_interleaved; x++) {
+        const block_iq4_nlx8 * b_ptr = (const block_iq4_nlx8 *) vx + (x * nb);
+
+        for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
+        for (int l = 0; l < nb; l++) {
+            for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+                for (int j = 0; j < ncols_interleaved; j++) {
+                    sumi = 0;
+                    for (int i = 0; i < blocklen; ++i) {
+                        const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];
+                        const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];
+                        sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2]));
+                    }
+                    sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d);
+                }
+            }
+        }
+        for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
+    }
+}

 void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
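A reading of the inner loops above (a restatement, not part of the diff): each 4-bit code selects an entry of the kvalues_iq4nl lookup table, and the result for output column j of an 8-column tile is

s_j = \sum_{l} d^{B}_{l,j}\, d^{A}_{l} \sum_{k,i} \big( \mathrm{lut}[q \,\&\, \mathrm{0xF}] \cdot a_{k,i} + \mathrm{lut}[q \gg 4] \cdot a_{k,i+qk/2} \big)

where d^B is the per-column FP16 scale of the interleaved iq4_nl block and d^A is the q8_0 activation scale.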
@ -934,6 +963,50 @@ void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||||
|
const int qk = QK8_0;
|
||||||
|
const int nb = n / qk;
|
||||||
|
const int ncols_interleaved = 8;
|
||||||
|
const int blocklen = 8;
|
||||||
|
|
||||||
|
assert(n % qk == 0);
|
||||||
|
assert(nr % 4 == 0);
|
||||||
|
assert(nc % ncols_interleaved == 0);
|
||||||
|
|
||||||
|
float sumf[4][8];
|
||||||
|
int sumi;
|
||||||
|
|
||||||
|
for (int y = 0; y < nr / 4; y++) {
|
||||||
|
const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
|
||||||
|
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||||
|
const block_iq4_nlx8 * b_ptr = (const block_iq4_nlx8 *) vx + (x * nb);
|
||||||
|
for (int m = 0; m < 4; m++) {
|
||||||
|
for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
|
||||||
|
}
|
||||||
|
for (int l = 0; l < nb; l++) {
|
||||||
|
for (int k = 0; k < (qk / (2 * blocklen)); k++) {
|
||||||
|
for (int m = 0; m < 4; m++) {
|
||||||
|
for (int j = 0; j < ncols_interleaved; j++) {
|
||||||
|
sumi = 0;
|
||||||
|
for (int i = 0; i < blocklen; ++i) {
|
||||||
|
const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];
|
||||||
|
const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];
|
||||||
|
sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
|
||||||
|
(v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4]));
|
||||||
|
}
|
||||||
|
sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (int m = 0; m < 4; m++) {
|
||||||
|
for (int j = 0; j < ncols_interleaved; j++)
|
||||||
|
s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
} // extern "C"
|
} // extern "C"
|
||||||
|
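Reviewer note (not part of the patch): a standalone sketch of the nibble layout the new generic 8x8 gemv/gemm loops assume, derived from the index arithmetic above. The helper name and the local copy of the iq4_nl codebook table are illustrative only; in ggml the table lives in the quantization sources.

// With QK8_0 = 32 and blocklen = 8, qs stores two 64-byte groups; byte
// k*64 + j*8 + i carries element (k*8 + i) of row j in its low nibble and
// element (k*8 + i + 16) in its high nibble, which is exactly how v0/v1
// pair up with the q8_0 activations in the kernels above.
#include <stdint.h>

static const int8_t kvalues_iq4nl_sketch[16] = {
    -127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113,
};

// Return the (pre-scale) codebook value for row j (0..7), element e (0..31)
// of one interleaved iq4_nlx8 block whose nibbles are stored in qs[128].
static inline int iq4_nlx8_code(const uint8_t * qs, int j, int e) {
    const int k = (e % 16) / 8;             // which 8-byte group along the row
    const int i = e % 8;                    // position inside the group
    const uint8_t b = qs[k * 64 + j * 8 + i];
    return kvalues_iq4nl_sketch[e < 16 ? (b & 0x0F) : (b >> 4)];
}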
 static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int blck_size_interleave) {
@@ -1302,15 +1375,16 @@ static int repack_iq4_nl_to_iq4_nl_4_bl(struct ggml_tensor * t, int interleave_b
         return 0;
     }
     GGML_ASSERT(t->type == GGML_TYPE_IQ4_NL);
-    //GGML_ASSERT(interleave_block == 4 || interleave_block == 8);
     GGML_ASSERT(interleave_block == 4);

-    block_iq4_nlx4 * dst = (block_iq4_nlx4 *)t->data;
     const block_iq4_nl * src = (const block_iq4_nl *)data;
+    block_iq4_nlx4 * dst = ( block_iq4_nlx4 *)t->data;

     block_iq4_nl dst_tmp[4];

     int nrow = ggml_nrows(t);
     int nrows_interleaved = 4;
-    int nblocks = t->ne[0] / QK4_0;
+    int nblocks = t->ne[0] / QK4_NL;

     GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_iq4_nl));

@@ -1332,6 +1406,63 @@ static int repack_iq4_nl_to_iq4_nl_4_bl(struct ggml_tensor * t, int interleave_b
     GGML_UNUSED(data_size);
 }

+static block_iq4_nlx8 make_block_iq4_nlx8(block_iq4_nl * in, unsigned int blck_size_interleave) {
+    block_iq4_nlx8 out;
+
+    for (int i = 0; i < 8; i++) {
+        out.d[i] = in[i].d;
+    }
+
+    const int end = QK4_NL * 4 / blck_size_interleave;
+
+    if (blck_size_interleave == 8) {
+        for (int i = 0; i < end; ++i) {
+            int src_id = i % 8;
+            int src_offset = (i / 8) * blck_size_interleave;
+            int dst_offset = i * blck_size_interleave;
+
+            memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint64_t));
+        }
+    } else {
+        GGML_ASSERT(false);
+    }
+
+    return out;
+}
+
+static int repack_iq4_nl_to_iq4_nl_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
+    GGML_ASSERT(t->type == GGML_TYPE_IQ4_NL);
+    GGML_ASSERT(interleave_block == 8);
+
+    const block_iq4_nl * src = (const block_iq4_nl *)data;
+    block_iq4_nlx8 * dst = ( block_iq4_nlx8 *)t->data;
+
+    block_iq4_nl dst_tmp[8];
+
+    int nrow = ggml_nrows(t);
+    int nrows_interleaved = 8;
+    int nblocks = t->ne[0] / QK4_NL;
+
+    GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_iq4_nl));
+
+    if (t->ne[1] % nrows_interleaved != 0) {
+        return -1;
+    }
+
+    for (int b = 0; b < nrow; b += nrows_interleaved) {
+        for (int64_t x = 0; x < nblocks; x++) {
+            for (int i = 0; i < nrows_interleaved; i++) {
+                dst_tmp[i] = src[x + i * nblocks];
+            }
+            *dst++ = make_block_iq4_nlx8(dst_tmp, interleave_block);
+        }
+        src += nrows_interleaved * nblocks;
+    }
+    return 0;
+
+    GGML_UNUSED(data_size);
+}
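Reviewer note (illustrative, not part of the patch): make_block_iq4_nlx8 weaves 8 source rows together 8 bytes (one uint64) at a time; destination group i comes from source row i % 8, group i / 8. A tiny sketch that prints this mapping, assuming QK4_NL == 32:

#include <stdio.h>

int main(void) {
    const int blck = 8;
    const int end  = 32 * 4 / blck;  // QK4_NL * 4 / blck_size_interleave = 16 groups
    for (int i = 0; i < end; ++i) {
        // dst bytes [i*8 .. i*8+8) <- row (i % 8), source bytes [(i/8)*8 .. (i/8)*8+8)
        printf("dst group %2d <- row %d, src group %d\n", i, i % 8, i / 8);
    }
    return 0;
}

So groups 0..7 carry the first 8 bytes of rows 0..7, and groups 8..15 carry their second 8 bytes, matching the indexing the 8x8 kernels expect.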
 namespace ggml::cpu::repack {
 // repack
 template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS>
@@ -1367,6 +1498,10 @@ template <> int repack<block_iq4_nl, 4, 4>(struct ggml_tensor * t, const void *
 //    return repack_iq4_nl_to_iq4_nl_4_bl(t, 8, data, data_size);
 //}

+template <> int repack<block_iq4_nl, 8, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
+    return repack_iq4_nl_to_iq4_nl_8_bl(t, 8, data, data_size);
+}
+
 // gemv
 template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PARAM_TYPE>
 void gemv(int, float *, size_t, const void *, const void *, int, int);
@@ -1395,6 +1530,10 @@ template <> void gemv<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size
     ggml_gemv_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
 }

+template <> void gemv<block_iq4_nl, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemv_iq4_nl_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
 // gemm
 template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PARAM_TYPE>
 void gemm(int, float *, size_t, const void *, const void *, int, int);
@@ -1423,6 +1562,10 @@ template <> void gemm<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size
     ggml_gemm_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
 }

+template <> void gemm<block_iq4_nl, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemm_iq4_nl_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
 class tensor_traits_base : public ggml::cpu::tensor_traits {
 public:
     virtual int repack(struct ggml_tensor * t, const void * data, size_t data_size) = 0;
@@ -1706,6 +1849,7 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons

     // instance for IQ4
     static const ggml::cpu::repack::tensor_traits<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0> iq4_nl_4x4_q8_0;
+    static const ggml::cpu::repack::tensor_traits<block_iq4_nl, 8, 8, GGML_TYPE_Q8_0> iq4_nl_8x8_q8_0;

     bool permit_repack = true;
 #if defined(GGML_USE_CLBLAST)
@@ -1741,6 +1885,11 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
             }
         }
     } else if (cur->type == GGML_TYPE_IQ4_NL) {
+        if (ggml_cpu_has_avx2()) {
+            if (cur->ne[1] % 8 == 0) {
+                return &iq4_nl_8x8_q8_0;
+            }
+        }
         if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
             if (cur->ne[1] % 4 == 0) {
                 return &iq4_nl_4x4_q8_0;
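Reviewer note (sketch, not part of the patch): the dispatch above now prefers the new AVX2 8x8 path and only then falls back to the existing NEON 4x4 path. Reduced to an enum for illustration (the real code returns tensor_traits instances):

enum iq4_repack { IQ4_REPACK_NONE, IQ4_REPACK_8X8_Q8_0, IQ4_REPACK_4X4_Q8_0 };

static enum iq4_repack choose_iq4_nl_repack(int has_avx2, int has_neon_dotprod, long long ne1) {
    if (has_avx2 && ne1 % 8 == 0) {
        return IQ4_REPACK_8X8_Q8_0;   // new AVX2 path added in this patch
    }
    if (has_neon_dotprod && ne1 % 4 == 0) {
        return IQ4_REPACK_4X4_Q8_0;   // existing NEON path
    }
    return IQ4_REPACK_NONE;           // no repacking
}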
@@ -67,6 +67,13 @@ struct block_iq4_nlx4 {

 static_assert(sizeof(block_iq4_nlx4) == 4 * sizeof(ggml_half) + QK4_NL * 2, "wrong iq4_nlx4 block size/padding");

+struct block_iq4_nlx8 {
+    ggml_half d[8];            // deltas for 8 iq4_nl blocks
+    uint8_t   qs[QK4_NL * 4];  // nibbles / quants for 8 iq4_nl blocks
+};
+
+static_assert(sizeof(block_iq4_nlx8) == 8 * sizeof(ggml_half) + QK4_NL * 4, "wrong iq4_nlx8 block size/padding");
+
 #if defined(__cplusplus)
 extern "C" {
 #endif

@@ -80,12 +87,14 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
 void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);

 // Native implementations
 void ggml_quantize_mat_q8_0_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
@@ -97,12 +106,14 @@ void ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs,
 void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);

 #if defined(__cplusplus)
 } // extern "C"
@@ -87,6 +87,10 @@
 #define GGML_CUDA_CC_IS_QY2(cc) (cc >= GGML_CUDA_CC_QY2 && cc < GGML_CUDA_CC_NG)
 #define GGML_CUDA_CC_IS_NG(cc)  (cc >= GGML_CUDA_CC_NG)

+#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
+#    define GGML_CUDA_USE_CUB
+#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
+
 #ifdef __CUDA_ARCH_LIST__
 constexpr bool ggml_cuda_has_arch_impl(int) {
     return false;
@@ -424,26 +428,6 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
 #endif // FP16_AVAILABLE
 }

-// Row reduction kernel template - compute sum (norm=false) or mean (norm=true)
-template<bool norm>
-static __global__ void reduce_rows_f32(const float * x, float * dst, const int ncols) {
-    const int row = blockIdx.x;
-    const int col = threadIdx.x;
-
-    float sum = 0.0f;
-    for (int i = col; i < ncols; i += blockDim.x) {
-        sum += x[row * ncols + i];
-    }
-
-    sum = warp_reduce_sum(sum);
-
-    if (col != 0) {
-        return;
-    }
-
-    dst[row] = norm ? sum / ncols : sum;
-}
-
 template<int width = WARP_SIZE>
 static __device__ __forceinline__ int warp_reduce_all(int x) {
 #ifdef GGML_USE_HIP
@@ -484,25 +468,21 @@ static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b
 }

 static __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const half2 b) {
-#if defined(GGML_USE_HIP) && HIP_VERSION >= 50700000
+#if defined(GGML_USE_HIP)
     return half2(__hmax(a.x, b.x), __hmax(a.y, b.y));
-#elif !defined(GGML_USE_HIP) && CUDART_VERSION >= CUDART_HMAX
+#elif CUDART_VERSION >= CUDART_HMAX
     return __hmax2(a, b);
-#elif !defined(GGML_USE_HIP)
+#else
     half2 ret;
     reinterpret_cast<half&>(ret.x) = __float2half(fmaxf( __low2float(a),  __low2float(b)));
     reinterpret_cast<half&>(ret.y) = __float2half(fmaxf(__high2float(a), __high2float(b)));
     return ret;
-#else
-    GGML_UNUSED(a);
-    GGML_UNUSED(b);
-    NO_DEVICE_CODE;
 #endif
 }

 template<int width = WARP_SIZE>
 static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
-#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL || (defined(GGML_USE_HIP) && HIP_VERSION >= 50700000)
+#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL || defined(GGML_USE_HIP)
 #pragma unroll
     for (int offset = width/2; offset > 0; offset >>= 1) {
         x = ggml_cuda_hmax2(x, __shfl_xor_sync(0xffffffff, x, offset, width));
@@ -511,7 +491,7 @@ static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
 #else
     GGML_UNUSED(x);
     NO_DEVICE_CODE;
-#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL || (defined(GGML_USE_HIP) && HIP_VERSION >= 50700000)
+#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL || defined(GGML_USE_HIP)
 }

 #if CUDART_VERSION < CUDART_HMASK
@@ -15,7 +15,6 @@ namespace wmma = mtmusa::wmma;
 namespace wmma = nvcuda::wmma;
 #endif // GGML_USE_MUSA
 #elif defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)
-#undef HIP_ENABLE_WARP_SYNC_BUILTINS // conflicts with rocWMMA headers
 #include <rocwmma/rocwmma.hpp>
 namespace wmma = rocwmma;
 #endif // !defined(GGML_USE_HIP)
@@ -30,6 +30,7 @@ bool g_mul_mat_q = true;
 #include "ggml-cuda/mmvq.cuh"
 #include "ggml-cuda/norm.cuh"
 #include "ggml-cuda/opt-step-adamw.cuh"
+#include "ggml-cuda/opt-step-sgd.cuh"
 #include "ggml-cuda/out-prod.cuh"
 #include "ggml-cuda/pad.cuh"
 #include "ggml-cuda/pool2d.cuh"
@@ -182,30 +183,6 @@ static int ggml_cuda_parse_id(char devName[]) {
 #endif // defined(GGML_USE_HIP)

 static ggml_cuda_device_info ggml_cuda_init() {
-#if defined(GGML_USE_HIP)
-    // Workaround for a rocBLAS bug when using multiple graphics cards:
-    // https://github.com/ROCmSoftwarePlatform/rocBLAS/issues/1346
-    {
-        int major_version = 0;
-        size_t version_length = 0;
-        if (rocblas_get_version_string_size(&version_length) == rocblas_status_success) {
-            std::vector<char> version(version_length+1, '\0');
-            if (rocblas_get_version_string(version.data(), version.size()) == rocblas_status_success) {
-                version.resize(::strlen(version.data()));
-                int parsed_value = 0;
-                if (std::from_chars(version.data(), version.data() + version.size(), parsed_value).ec == std::errc()) {
-                    major_version = parsed_value;
-                }
-            }
-        }
-        if (major_version < 4) {
-            GGML_LOG_DEBUG(GGML_CUDA_NAME " calling rocblas_initialize as a workaround for a rocBLAS bug\n");
-            rocblas_initialize();
-            CUDA_CHECK(cudaDeviceSynchronize());
-        }
-    }
-#endif
-
     ggml_cuda_device_info info = {};

     cudaError_t err = cudaGetDeviceCount(&info.device_count);
@@ -2516,6 +2493,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_OPT_STEP_ADAMW:
             ggml_cuda_opt_step_adamw(ctx, dst);
             break;
+        case GGML_OP_OPT_STEP_SGD:
+            ggml_cuda_opt_step_sgd(ctx, dst);
+            break;
         default:
             return false;
     }
@@ -3573,6 +3553,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_CROSS_ENTROPY_LOSS:
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
         case GGML_OP_OPT_STEP_ADAMW:
+        case GGML_OP_OPT_STEP_SGD:
            return true;
        default:
            return false;
@@ -1,4 +1,14 @@
 #include "mean.cuh"
+#include "reduce_rows.cuh"
+
+#ifdef GGML_CUDA_USE_CUB
+#include <cub/cub.cuh>
+using namespace cub;
+#endif // GGML_CUDA_USE_CUB
+
+template <typename T> __global__ void divide_by_count(T * result, size_t count) {
+    *result /= static_cast<T>(count);
+}

 void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
@@ -13,7 +23,51 @@ void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const int64_t ncols = src0->ne[0];
     const int64_t nrows = ggml_nrows(src0);

-    const dim3 block_dims(WARP_SIZE, 1, 1);
-    const dim3 block_nums(nrows, 1, 1);
-    reduce_rows_f32</*norm*/ true><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
+    // Special case for reducing vectors
+#ifdef GGML_CUDA_USE_CUB
+#ifdef USE_CUDA_GRAPH
+    cudaStreamCaptureStatus iscapturing;
+    CUDA_CHECK(cudaStreamIsCapturing(stream, &iscapturing));
+#endif // USE_CUDA_GRAPH
+    if ((nrows == 1) &&
+#ifdef USE_CUDA_GRAPH
+        // CUDA_GRAPHS_DISABLED
+        ((ncols > 65536) &&
+         ((ctx.cuda_graph->instance == nullptr) && (iscapturing == cudaStreamCaptureStatusNone) ||
+          ctx.cuda_graph->disable_due_to_gpu_arch || ctx.cuda_graph->disable_due_to_too_many_updates ||
+          ctx.cuda_graph->disable_due_to_failed_graph_capture)) ||
+        // CUDA_GRAPHS ENABLED
+        ((ncols > 32768) &&
+         !((ctx.cuda_graph->instance == nullptr) && (iscapturing == cudaStreamCaptureStatusNone) ||
+           ctx.cuda_graph->disable_due_to_gpu_arch || ctx.cuda_graph->disable_due_to_too_many_updates ||
+           ctx.cuda_graph->disable_due_to_failed_graph_capture))) {
+#else
+        (ncols > 65536)) {
+#endif // USE_CUDA_GRAPH
+        // Single row - use device-wide reduction
+        size_t tmp_size = 0;
+        ggml_cuda_pool & pool = ctx.pool();
+
+        DeviceReduce::Sum(nullptr, tmp_size, src0_d, dst_d, ncols, stream);
+
+        ggml_cuda_pool_alloc<uint8_t> tmp_alloc(pool, tmp_size);
+        DeviceReduce::Sum(tmp_alloc.ptr, tmp_size, src0_d, dst_d, ncols, stream);
+
+        // Divide by ncols
+        divide_by_count<float><<<1, 1, 0, stream>>>(dst_d, ncols);
+        return;
+    }
+#endif // GGML_CUDA_USE_CUB
+
+    const dim3 block_nums(nrows, 1, 1);
+
+    const int id  = ggml_cuda_get_device();
+    const int nsm = ggml_cuda_info().devices[id].nsm;
+    if ((nrows / nsm) < 2) {
+        const dim3 block_dims(512, 1, 1);
+        reduce_rows_f32</*norm=*/true><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
+    } else {
+        const dim3 block_dims(ncols < 1024 ? 32 : 128, 1, 1);
+        reduce_rows_f32</*norm=*/true><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
+    }
 }
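Reviewer note (sketch under assumptions, not part of the patch): the single-row fast path above follows CUB's usual two-call idiom: one call with a null buffer to query the scratch size, then the real call. ggml uses its own pool allocator; cudaMallocAsync/cudaFreeAsync below are only stand-ins to keep the sketch self-contained.

#include <cub/cub.cuh>

static void sum_with_cub(const float * d_in, float * d_out, int n, cudaStream_t stream) {
    size_t tmp_bytes = 0;
    cub::DeviceReduce::Sum(nullptr, tmp_bytes, d_in, d_out, n, stream);  // size query only

    void * d_tmp = nullptr;
    cudaMallocAsync(&d_tmp, tmp_bytes, stream);                          // placeholder allocator
    cub::DeviceReduce::Sum(d_tmp, tmp_bytes, d_in, d_out, n, stream);    // actual reduction
    cudaFreeAsync(d_tmp, stream);
}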
ggml/src/ggml-cuda/opt-step-sgd.cu (new file, 49 lines)
@@ -0,0 +1,49 @@
+#include "ggml-impl.h"
+#include "opt-step-sgd.cuh"
+
+#include <cstdint>
+
+static __global__ void opt_step_sgd_f32(
+    float * __restrict__ x, const float * __restrict__ g,
+    const float * __restrict__ pars, const int64_t k) {
+
+    const int64_t i = (int64_t) blockIdx.x*blockDim.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+    x[i] = x[i] * (1.0f - pars[0] * pars[1]) - pars[0] * g[i];
+}
+
+static void opt_step_sgd_f32_cuda(
+    float * x, const float * g, const float * __restrict__ pars, const int64_t k, cudaStream_t stream) {
+
+    const dim3 block_dims(CUDA_OPT_STEP_SGD_BLOCK_SIZE, 1, 1);
+    const dim3 block_nums((k + CUDA_OPT_STEP_SGD_BLOCK_SIZE - 1) / CUDA_OPT_STEP_SGD_BLOCK_SIZE, 1, 1);
+    opt_step_sgd_f32<<<block_nums, block_dims, 0, stream>>>(x, g, pars, k);
+}
+
+void ggml_cuda_opt_step_sgd(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0      = dst->src[0];
+    const ggml_tensor * src0_grad = dst->src[1];
+    const ggml_tensor * params    = dst->src[2];
+
+    GGML_ASSERT(src0->type      == GGML_TYPE_F32);
+    GGML_ASSERT(src0_grad->type == GGML_TYPE_F32);
+    GGML_ASSERT(params->type    == GGML_TYPE_F32);
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src0_grad));
+    GGML_ASSERT(ggml_is_contiguous(params));
+    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
+    GGML_ASSERT(ggml_nelements(params) == 2);
+
+    float       * src0_d      = (float       *) src0->data;
+    const float * src0_grad_d = (const float *) src0_grad->data;
+    const float * params_d    = (const float *) params->data;
+
+    cudaStream_t stream = ctx.stream();
+
+    const int64_t ne = ggml_nelements(src0);
+
+    opt_step_sgd_f32_cuda(src0_d, src0_grad_d, params_d, ne, stream);
+}

ggml/src/ggml-cuda/opt-step-sgd.cuh (new file, 5 lines)
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_OPT_STEP_SGD_BLOCK_SIZE 256
+
+void ggml_cuda_opt_step_sgd(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

ggml/src/ggml-cuda/reduce_rows.cuh (new file, 53 lines)
@@ -0,0 +1,53 @@
+#include "common.cuh"
+
+// Row reduction kernel template - compute sum (norm=false) or mean (norm=true)
+template <bool norm>
+static __global__ void reduce_rows_f32(const float * __restrict__ x, float * __restrict__ dst, const int ncols) {
+    const int row = blockIdx.x;
+    const int col = threadIdx.x;
+
+    float sum = 0.0f;
+    const int num_unroll = 8;
+    float temp[num_unroll];
+    float sum_temp[num_unroll] = { 0.0f };
+    for (int i = col; i < ncols;) {
+        for (int j = 0; j < num_unroll; ++j) {
+            if (i < ncols) {
+                temp[j] = x[row * ncols + i];
+            } else {
+                temp[j] = 0;
+            }
+            i += blockDim.x;
+        }
+        for (int j = 0; j < num_unroll; ++j) {
+            sum_temp[j] += temp[j];
+        }
+    }
+    for (int j = 0; j < num_unroll; ++j) {
+        sum += sum_temp[j];
+    }
+
+    // sum up partial sums
+    sum = warp_reduce_sum(sum);
+    if (blockDim.x > WARP_SIZE) {
+        assert((blockDim.x <= 1024) && (blockDim.x % WARP_SIZE) == 0);
+        __shared__ float s_sum[32];
+        const int warp_id = threadIdx.x / WARP_SIZE;
+        const int lane_id = threadIdx.x % WARP_SIZE;
+        if (lane_id == 0) {
+            s_sum[warp_id] = sum;
+        }
+        __syncthreads();
+        sum = 0.0f;
+        if (lane_id < (blockDim.x / WARP_SIZE)) {
+            sum = s_sum[lane_id];
+        }
+        sum = warp_reduce_sum(sum);
+    }
+
+    if (col != 0) {
+        return;
+    }
+
+    dst[row] = norm ? sum / ncols : sum;
+}

@@ -1,19 +1,15 @@
-#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
-#define USE_CUB
-#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
+#include "sum.cuh"
+#include "sumrows.cuh"

-#ifdef USE_CUB
+#ifdef GGML_CUDA_USE_CUB
 #include <cub/cub.cuh>
 using namespace cub;
-#endif // USE_CUB
+#endif // GGML_CUDA_USE_CUB

-#include "sumrows.cuh"
-#include "sum.cuh"
-
 #include <cstdint>

 void sum_f32_cuda(ggml_cuda_pool & pool, const float * x, float * dst, const int64_t ne, cudaStream_t stream) {
-#ifdef USE_CUB
+#ifdef GGML_CUDA_USE_CUB
     size_t tmp_size = 0;
     DeviceReduce::Sum(nullptr, tmp_size, x, dst, ne, stream);
     ggml_cuda_pool_alloc<uint8_t> tmp_alloc(pool, tmp_size);
@@ -23,7 +19,7 @@ void sum_f32_cuda(ggml_cuda_pool & pool, const float * x, float * dst, const int
     // For AMD there is rocPRIM which could be used as a drop-in replacement via hipcub but this would require C++11 -> C++14.
     sum_rows_f32_cuda(x, dst, ne, 1, stream);
     GGML_UNUSED(pool);
-#endif // USE_CUB
+#endif // GGML_CUDA_USE_CUB
 }

 void ggml_cuda_op_sum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {

@@ -1,9 +1,17 @@
+#include "reduce_rows.cuh"
 #include "sumrows.cuh"

 void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    const dim3 block_dims(WARP_SIZE, 1, 1);
+    const int id  = ggml_cuda_get_device();
+    const int nsm = ggml_cuda_info().devices[id].nsm;
     const dim3 block_nums(nrows, 1, 1);
-    reduce_rows_f32</*norm*/false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
+    if ((nrows / nsm) < 2) {
+        const dim3 block_dims(512, 1, 1);
+        reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
+    } else {
+        const dim3 block_dims(ncols < 1024 ? 32 : 128, 1, 1);
+        reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
+    }
 }

 void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -19,8 +27,17 @@ void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const int64_t ncols = src0->ne[0];
     const int64_t nrows = ggml_nrows(src0);

-    const dim3 block_dims(WARP_SIZE, 1, 1);
     const dim3 block_nums(nrows, 1, 1);

+    const int id  = ggml_cuda_get_device();
+    const int nsm = ggml_cuda_info().devices[id].nsm;
+    if ((nrows / nsm) < 2) {
+        // Increase num threads to 512 for small nrows to better hide the latency
+        const dim3 block_dims(512, 1, 1);
+        reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
+    } else {
+        // Enough active SMs to hide latency, use smaller blocks to allow better scheduling
+        const dim3 block_dims(ncols < 1024 ? 32 : 128, 1, 1);
         reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
+    }
 }
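Reviewer note on the new opt-step-sgd.cu kernel above (illustrative, not part of the patch): the update is SGD with weight decay, x <- x * (1 - alpha * wd) - alpha * g, where pars[0] is the learning rate alpha and pars[1] is the weight decay wd, matching the 2-element params tensor that ggml-opt fills in. A plain CPU reference of the same per-element update:

static void opt_step_sgd_f32_ref(float * x, const float * g, const float * pars, long long k) {
    const float alpha = pars[0];  // learning rate
    const float wd    = pars[1];  // weight decay
    for (long long i = 0; i < k; ++i) {
        x[i] = x[i] * (1.0f - alpha * wd) - alpha * g[i];
    }
}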
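Reviewer note on the reduce_rows launch heuristic above (sketch, not part of the patch): when there are fewer than 2 rows per SM, wide 512-thread blocks are used to hide latency inside a block; otherwise narrow blocks (32 or 128 threads depending on row length) leave the scheduler more freedom. The choice, reduced to a helper:

static int pick_block_dim_x(int nrows, int nsm, int ncols) {
    if ((nrows / nsm) < 2) {
        return 512;                      // few rows: go wide to hide latency
    }
    return ncols < 1024 ? 32 : 128;      // many rows: smaller blocks schedule better
}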
ggml/src/ggml-cuda/vendors/hip.h (vendored)
@@ -1,12 +1,10 @@
 #pragma once

-#define HIP_ENABLE_WARP_SYNC_BUILTINS 1
+#define HIP_DISABLE_WARP_SYNC_BUILTINS 1
 #include <hip/hip_runtime.h>
 #include <hipblas/hipblas.h>
 #include <hip/hip_fp16.h>
 #include <hip/hip_bfloat16.h>
-// for rocblas_initialize()
-#include "rocblas/rocblas.h"

 #define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
 #define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
@@ -251,17 +249,3 @@ static __device__ __forceinline__ unsigned int __vcmpne4(unsigned int a, unsigne
     }
     return c;
 }
-
-#if HIP_VERSION < 50600000
-// __shfl_xor() for half2 was added in ROCm 5.6
-static __device__ __forceinline__ half2 __shfl_xor(half2 var, int laneMask, int width) {
-    typedef union half2_b32 {
-        half2 val;
-        int   b32;
-    } half2_b32_t;
-    half2_b32_t tmp;
-    tmp.val = var;
-    tmp.b32 = __shfl_xor(tmp.b32, laneMask, width);
-    return tmp.val;
-}
-#endif // HIP_VERSION < 50600000

@@ -66,7 +66,9 @@ struct ggml_opt_context {

     ggml_opt_get_optimizer_params get_opt_pars = nullptr;
     void * get_opt_pars_ud                     = nullptr;
-    struct ggml_tensor * adamw_params          = nullptr;
+    struct ggml_tensor * opt_step_params       = nullptr; // Stores output of get_opt_pars.
+
+    enum ggml_opt_optimizer_type optimizer = GGML_OPT_OPTIMIZER_TYPE_ADAMW;
 };

 struct ggml_opt_result {
@@ -229,9 +231,13 @@ struct ggml_opt_optimizer_params ggml_opt_get_default_optimizer_params(void * us
     result.adamw.eps   = 1e-8f;
     result.adamw.wd    = 0.0f;

+    result.sgd.alpha = 1e-3f;
+    result.sgd.wd    = 0.0f;
+
     return result;
 }

 struct ggml_opt_optimizer_params ggml_opt_get_constant_optimizer_params(void * userdata) {
     return *((struct ggml_opt_optimizer_params *) userdata);
 }
@@ -249,6 +255,7 @@ struct ggml_opt_params ggml_opt_default_params(
         /*opt_period      =*/ 1,
         /*get_opt_pars    =*/ ggml_opt_get_default_optimizer_params,
         /*get_opt_pars_ud =*/ nullptr,
+        /*optimizer       =*/ GGML_OPT_OPTIMIZER_TYPE_ADAMW,
     };
 }

@@ -316,9 +323,14 @@ static void ggml_opt_build(ggml_opt_context_t opt_ctx) {
     GGML_ASSERT(opt_ctx->ctx_compute && "no compute context set, either use static graphs or set one with ggml_opt_prepare_alloc");
     GGML_ASSERT((!opt_ctx->static_graphs || opt_ctx->inputs->data) && "when using static graphs the inputs must be allocated statically");

+    const enum ggml_opt_optimizer_type optimizer = opt_ctx->optimizer;
+
     const bool accumulate = opt_ctx->build_type_alloc >= GGML_OPT_BUILD_TYPE_GRAD &&
         !(opt_ctx->static_graphs && opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_OPT && opt_ctx->opt_period == 1);

+    const bool need_momenta = opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_OPT &&
+        opt_ctx->optimizer == GGML_OPT_OPTIMIZER_TYPE_ADAMW;
+
     ggml_set_input(opt_ctx->inputs);
     ggml_set_output(opt_ctx->outputs);

@@ -340,8 +352,7 @@ static void ggml_opt_build(ggml_opt_context_t opt_ctx) {
     //   - pred (if using static graphs)
     //   - ncorrect (if using static graphs, 2 tensors).
     constexpr size_t n_loss = 1;
-    const size_t tensors_per_param = (accumulate ? 1 : 0) +
-        (opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_OPT ? 2 : 0);
+    const size_t tensors_per_param = (accumulate ? 1 : 0) + (need_momenta ? 2 : 0);
     const size_t tensors_const = opt_ctx->static_graphs ? 9 : 0;
     const size_t size_meta = (n_loss + tensors_per_param*n_param + tensors_const) * ggml_tensor_overhead();
     struct ggml_init_params params = {
@@ -458,7 +469,7 @@ static void ggml_opt_build(ggml_opt_context_t opt_ctx) {
         }
     }

-    if (opt_ctx->build_type_alloc >= GGML_OPT_BUILD_TYPE_OPT) {
+    if (need_momenta && opt_ctx->build_type_alloc >= GGML_OPT_BUILD_TYPE_OPT) {
         opt_ctx->grad_m.resize(n_nodes);
         opt_ctx->grad_v.resize(n_nodes);
         for (int i = 0; i < n_nodes; ++i) {
@@ -492,23 +503,36 @@ static void ggml_opt_build(ggml_opt_context_t opt_ctx) {
     // gb_opt == graph backward optimize, forward pass, then backward pass to calculate gradients, then optimizer step.
     opt_ctx->gb_opt = ggml_graph_dup(opt_ctx->ctx_compute, opt_ctx->gb_grad, /*force_grads =*/ true);

-    opt_ctx->adamw_params = ggml_new_tensor_1d(opt_ctx->ctx_cpu, GGML_TYPE_F32, 7);
-    ggml_set_input(opt_ctx->adamw_params);
-    ggml_set_name(opt_ctx->adamw_params, "adamw_params");
+    opt_ctx->opt_step_params = ggml_new_tensor_1d(opt_ctx->ctx_cpu, GGML_TYPE_F32, need_momenta ? 7 : 2);
+    ggml_tensor * adamw_params = opt_ctx->opt_step_params;
+    ggml_set_input(adamw_params);
+    const char * optimizer_name = ggml_opt_optimizer_name(opt_ctx->optimizer);
+    ggml_format_name(adamw_params, "%s_params", optimizer_name);
     for (int i = opt_ctx->gf->n_nodes-1; i >= 0; --i) {
         struct ggml_tensor * node = opt_ctx->gb_opt->nodes[i];
         struct ggml_tensor * grad = ggml_graph_get_grad(opt_ctx->gb_opt, node);

         if (grad && (node->flags & GGML_TENSOR_FLAG_PARAM)) {
-            struct ggml_tensor * m        = opt_ctx->grad_m[i];
-            struct ggml_tensor * v        = opt_ctx->grad_v[i];
-            struct ggml_tensor * opt_step = ggml_opt_step_adamw(opt_ctx->ctx_compute, node, grad, m, v, opt_ctx->adamw_params);
-
-            ggml_set_name(m,        (std::string("AdamW m for ")    + std::string(node->name)).c_str());
-            ggml_set_name(v,        (std::string("AdamW v for ")    + std::string(node->name)).c_str());
-            ggml_set_name(opt_step, (std::string("AdamW step for ") + std::string(node->name)).c_str());
+            struct ggml_tensor * m = nullptr;
+            struct ggml_tensor * v = nullptr;
+            if (need_momenta) {
+                m = opt_ctx->grad_m[i];
+                v = opt_ctx->grad_v[i];
+                ggml_format_name(m, "AdamW m for %s", node->name);
+                ggml_format_name(v, "AdamW v for %s", node->name);
+            }
+            struct ggml_tensor * opt_step;
+            switch (optimizer) {
+                case GGML_OPT_OPTIMIZER_TYPE_ADAMW:
+                    opt_step = ggml_opt_step_adamw(opt_ctx->ctx_compute, node, grad, m, v, adamw_params);
+                    break;
+                case GGML_OPT_OPTIMIZER_TYPE_SGD:
+                    opt_step = ggml_opt_step_sgd(opt_ctx->ctx_compute, node, grad, adamw_params);
+                    break;
+                default:
+                    GGML_ABORT("fatal error");
+            }
+            ggml_format_name(opt_step, "%s step for %s", optimizer_name, node->name);
             ggml_build_forward_expand(opt_ctx->gb_opt, opt_step);
         }
     }
@@ -534,6 +558,7 @@ ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params) {
     result->opt_period       = params.opt_period;
     result->get_opt_pars     = params.get_opt_pars;
     result->get_opt_pars_ud  = params.get_opt_pars_ud;
+    result->optimizer        = params.optimizer;

     GGML_ASSERT(result->opt_period >= 1);

@@ -756,8 +781,10 @@ void ggml_opt_alloc(ggml_opt_context_t opt_ctx, bool backward) {
 void ggml_opt_eval(ggml_opt_context_t opt_ctx, ggml_opt_result_t result) {
     GGML_ASSERT(opt_ctx->eval_ready);
     if (opt_ctx->allocated_graph == opt_ctx->gb_opt) {
-        struct ggml_opt_optimizer_params opt_pars = opt_ctx->get_opt_pars(opt_ctx->get_opt_pars_ud);
+        const ggml_opt_optimizer_params & opt_pars = opt_ctx->get_opt_pars(opt_ctx->get_opt_pars_ud);

+        switch (opt_ctx->optimizer) {
+            case GGML_OPT_OPTIMIZER_TYPE_ADAMW: {
                 GGML_ASSERT(opt_pars.adamw.alpha >  0.0f);
                 GGML_ASSERT(opt_pars.adamw.beta1 >= 0.0f);
                 GGML_ASSERT(opt_pars.adamw.beta1 <= 1.0f);
@@ -771,7 +798,7 @@ void ggml_opt_eval(ggml_opt_context_t opt_ctx, ggml_opt_result_t result) {
                 const float beta1h = 1.0f/(1.0f - powf(opt_pars.adamw.beta1, opt_ctx->iter));
                 const float beta2h = 1.0f/(1.0f - powf(opt_pars.adamw.beta2, opt_ctx->iter));

-        float * adamw_par_data = ggml_get_data_f32(opt_ctx->adamw_params);
+                float * adamw_par_data = ggml_get_data_f32(opt_ctx->opt_step_params);
                 adamw_par_data[0] = opt_pars.adamw.alpha;
                 adamw_par_data[1] = opt_pars.adamw.beta1;
                 adamw_par_data[2] = opt_pars.adamw.beta2;
@@ -779,6 +806,18 @@ void ggml_opt_eval(ggml_opt_context_t opt_ctx, ggml_opt_result_t result) {
                 adamw_par_data[4] = opt_pars.adamw.wd;
                 adamw_par_data[5] = beta1h;
                 adamw_par_data[6] = beta2h;
+            } break;
+            case GGML_OPT_OPTIMIZER_TYPE_SGD: {
+                GGML_ASSERT(opt_pars.sgd.alpha > 0.0f);
+                GGML_ASSERT(opt_pars.sgd.wd >= 0.0f);
+                GGML_ASSERT(opt_pars.sgd.wd <= 1.0f);
+                float * sgd = ggml_get_data_f32(opt_ctx->opt_step_params);
+                sgd[0] = opt_pars.sgd.alpha;
+                sgd[1] = opt_pars.sgd.wd;
+            } break;
+            default:
+                GGML_ABORT("fatal error");
+        }
     }

     ggml_backend_sched_graph_compute(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy);
@@ -963,6 +1002,7 @@ void ggml_opt_fit(
         ggml_tensor * outputs,
         ggml_opt_dataset_t dataset,
         enum ggml_opt_loss_type loss_type,
+        enum ggml_opt_optimizer_type optimizer,
         ggml_opt_get_optimizer_params get_opt_pars,
         int64_t nepoch,
         int64_t nbatch_logical,
@@ -993,6 +1033,7 @@ void ggml_opt_fit(
     params.opt_period      = opt_period;
     params.get_opt_pars    = get_opt_pars;
     params.get_opt_pars_ud = &epoch;
+    params.optimizer       = optimizer;
     ggml_opt_context_t opt_ctx = ggml_opt_init(params);

     // Shuffling the data is generally useful but there is only a point if not all data is used in a single batch.
@@ -1035,3 +1076,18 @@ void ggml_opt_fit(
     ggml_opt_result_free(result_train);
     ggml_opt_result_free(result_val);
 }
+
+enum ggml_opt_optimizer_type ggml_opt_context_optimizer_type(ggml_opt_context_t c) {
+    return c->optimizer;
+}
+
+GGML_API const char * ggml_opt_optimizer_name(enum ggml_opt_optimizer_type o) {
+    switch (o) {
+        case GGML_OPT_OPTIMIZER_TYPE_ADAMW:
+            return "adamw";
+        case GGML_OPT_OPTIMIZER_TYPE_SGD:
+            return "sgd";
+        default:
+            return "undefined";
+    };
+}
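Reviewer note (hypothetical usage sketch, not part of the patch): how a caller might opt into the new SGD path, using only the fields and functions introduced above; the surrounding context, backend, and graph setup are elided and the exact initialization of ggml_opt_params is assumed.

static struct ggml_opt_optimizer_params my_sgd_pars(void * userdata) {
    struct ggml_opt_optimizer_params p = ggml_opt_get_default_optimizer_params(userdata);
    p.sgd.alpha = 1e-2f;   // learning rate
    p.sgd.wd    = 0.0f;    // weight decay
    return p;
}

// ... after filling the usual ggml_opt_params fields:
// params.optimizer    = GGML_OPT_OPTIMIZER_TYPE_SGD;  // new field from this patch
// params.get_opt_pars = my_sgd_pars;
// ggml_opt_context_t opt_ctx = ggml_opt_init(params);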
@@ -526,6 +526,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_rwkv_wkv6_f32;
     vk_pipeline pipeline_rwkv_wkv7_f32;
     vk_pipeline pipeline_opt_step_adamw_f32;
+    vk_pipeline pipeline_opt_step_sgd_f32;
     vk_pipeline pipeline_conv2d_f32[CONV_SHAPE_COUNT];
     vk_pipeline pipeline_conv2d_f16_f32[CONV_SHAPE_COUNT];
     vk_pipeline pipeline_conv2d_dw_whcn_f32;
@@ -3139,6 +3140,8 @@ static void ggml_vk_load_shaders(vk_device& device) {

     ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);

+    ggml_vk_create_pipeline(device, device->pipeline_opt_step_sgd_f32, "opt_step_sgd_f32", opt_step_sgd_f32_len, opt_step_sgd_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
+
     // conv2d
     for (uint32_t s = 0; s < CONV_SHAPE_COUNT; ++s) {
         uint32_t conv2d_WG_SIZE = 256;
@@ -7223,6 +7226,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_opt_step_adamw_f32;
         }
         return nullptr;
+    case GGML_OP_OPT_STEP_SGD:
+        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_opt_step_sgd_f32;
+        }
+        return nullptr;
     case GGML_OP_LEAKY_RELU:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
             return ctx->device->pipeline_leaky_relu_f32;
@@ -7722,6 +7730,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
         ggml_vk_buffer_memset_async(subctx, d_D, d_buf_offset, 0, d_sz);
         ggml_vk_sync_buffers(subctx);
         ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
+    } else if (op == GGML_OP_OPT_STEP_SGD) {
+        // OPT_STEP_SGD works on src0, it does not need dst
+        ggml_vk_sync_buffers(subctx);
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz } }, pc, elements);
     } else if (use_src2) {
         ggml_vk_sync_buffers(subctx);
         ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
@@ -8075,6 +8087,12 @@ static void ggml_vk_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& su
     );
 }

+static void ggml_vk_opt_step_sgd(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) {
+    const size_t n = ggml_nelements(dst->src[0]);
+
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_OPT_STEP_SGD, { (uint32_t)n, 0, 0.0f, 0.0f }, dryrun);
+}
+
 static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
     int * op_params = (int *)dst->op_params;

@@ -9628,6 +9646,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     case GGML_OP_LEAKY_RELU:
     case GGML_OP_FLASH_ATTN_EXT:
     case GGML_OP_OPT_STEP_ADAMW:
+    case GGML_OP_OPT_STEP_SGD:
         break;
     default:
         std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(node->op) << std::endl;
@@ -9692,6 +9711,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     case GGML_OP_CONV_2D:
     case GGML_OP_CONV_2D_DW:
     case GGML_OP_LEAKY_RELU:
+    case GGML_OP_OPT_STEP_SGD:
         {
             // These operations all go through ggml_vk_op_f32, so short-circuit and
             // do the only thing needed for the dryrun.
@@ -9941,6 +9961,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     case GGML_OP_OPT_STEP_ADAMW:
         ggml_vk_opt_step_adamw(ctx, compute_ctx, node, dryrun);

+        break;
+
+    case GGML_OP_OPT_STEP_SGD:
+        ggml_vk_opt_step_sgd(ctx, compute_ctx, src0, src1, src2, node, dryrun);
+
         break;
     default:
         return false;
@@ -10044,8 +10069,8 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
     case GGML_OP_REPEAT:
     case GGML_OP_REPEAT_BACK:
     case GGML_OP_OPT_STEP_ADAMW:
+    case GGML_OP_OPT_STEP_SGD:
         buf = tensor->buffer;
-
         break;
     case GGML_OP_UNARY:
         switch (ggml_get_unary_op(tensor)) {
@@ -11184,6 +11209,9 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
     case GGML_OP_SIN:
|
||||||
case GGML_OP_COS:
|
case GGML_OP_COS:
|
||||||
case GGML_OP_CLAMP:
|
case GGML_OP_CLAMP:
|
||||||
|
case GGML_OP_LEAKY_RELU:
|
||||||
|
case GGML_OP_OPT_STEP_ADAMW:
|
||||||
|
case GGML_OP_OPT_STEP_SGD:
|
||||||
return op->src[0]->type == GGML_TYPE_F32;
|
return op->src[0]->type == GGML_TYPE_F32;
|
||||||
case GGML_OP_UPSCALE:
|
case GGML_OP_UPSCALE:
|
||||||
case GGML_OP_ACC:
|
case GGML_OP_ACC:
|
||||||
|
@ -11205,8 +11233,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||||
case GGML_OP_POOL_2D:
|
case GGML_OP_POOL_2D:
|
||||||
case GGML_OP_RWKV_WKV6:
|
case GGML_OP_RWKV_WKV6:
|
||||||
case GGML_OP_RWKV_WKV7:
|
case GGML_OP_RWKV_WKV7:
|
||||||
case GGML_OP_LEAKY_RELU:
|
|
||||||
case GGML_OP_OPT_STEP_ADAMW:
|
|
||||||
return true;
|
return true;
|
||||||
case GGML_OP_CONV_TRANSPOSE_1D:
|
case GGML_OP_CONV_TRANSPOSE_1D:
|
||||||
return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
|
return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
|
||||||
|
@ -11804,6 +11830,10 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
|
||||||
src_clone[0]->flags = src0->flags;
|
src_clone[0]->flags = src0->flags;
|
||||||
tensor_clone = ggml_opt_step_adamw(ggml_ctx, src_clone[0], src_clone[1],
|
tensor_clone = ggml_opt_step_adamw(ggml_ctx, src_clone[0], src_clone[1],
|
||||||
src_clone[2], src_clone[3], src_clone[4]);
|
src_clone[2], src_clone[3], src_clone[4]);
|
||||||
|
} else if (tensor->op == GGML_OP_OPT_STEP_SGD) {
|
||||||
|
src_clone[0]->flags = src0->flags;
|
||||||
|
tensor_clone = ggml_opt_step_sgd(ggml_ctx, src_clone[0], src_clone[1],
|
||||||
|
src_clone[2]);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
|
std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
|
||||||
|
|
22
ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp
Normal file
@@ -0,0 +1,22 @@
+#version 450
+
+#include "generic_head.comp"
+
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) buffer X {A_TYPE data_x[];};
+layout (binding = 1) readonly buffer G {A_TYPE data_grad[];};
+layout (binding = 2) readonly buffer P {float data_params[2];};
+
+void main() {
+    const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
+
+    if (i >= p.KX) {
+        return;
+    }
+
+    const float alpha = data_params[0];
+    const float keep  = 1.f - alpha * data_params[1];
+
+    data_x[i] = data_x[i] * keep - alpha * data_grad[i];
+}
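The shader applies one SGD step with decoupled weight decay to each parameter element; p.KX is the element count that ggml_vk_opt_step_sgd passes through the push constants, and data_params holds {learning rate, weight decay}. As a rough, host-side illustration only (not part of the commit; the function name is made up for the example), the same update in plain C++ looks like this:

#include <cstddef>

// Reference sketch of the per-element update performed by opt_step_sgd.comp:
//   x[i] <- (1 - alpha*wd) * x[i] - alpha * grad[i]
// params[0] is the learning rate alpha, params[1] the weight decay wd.
static void sgd_step_reference(float * x, const float * grad, const float * params, size_t n) {
    const float alpha = params[0];
    const float keep  = 1.0f - alpha * params[1];
    for (size_t i = 0; i < n; ++i) {
        x[i] = x[i] * keep - alpha * grad[i];
    }
}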
@@ -671,6 +671,7 @@ void process_shaders() {
    string_to_spv("rwkv_wkv7_f32", "wkv7.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
 
    string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
+    string_to_spv("opt_step_sgd_f32", "opt_step_sgd.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
 
    string_to_spv("conv2d_f32_unroll", "conv2d_mm.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", "[[unroll]]"}});
    string_to_spv("conv2d_f16_f32_unroll", "conv2d_mm.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", "[[unroll]]"}});
114
ggml/src/ggml.c
@@ -1028,11 +1028,12 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
    "CROSS_ENTROPY_LOSS",
    "CROSS_ENTROPY_LOSS_BACK",
    "OPT_STEP_ADAMW",
+    "OPT_STEP_SGD",
 
    "GLU",
 };
 
-static_assert(GGML_OP_COUNT == 87, "GGML_OP_COUNT != 87");
+static_assert(GGML_OP_COUNT == 88, "GGML_OP_COUNT != 88");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
    "none",
@@ -1129,15 +1130,15 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
    "cross_entropy_loss(x,y)",
    "cross_entropy_loss_back(x,y)",
    "adamw(x)",
+    "sgd(x)",
 
    "glu(x)",
 };
 
-static_assert(GGML_OP_COUNT == 87, "GGML_OP_COUNT != 87");
+static_assert(GGML_OP_COUNT == 88, "GGML_OP_COUNT != 88");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
 
 static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
    "ABS",
    "SGN",
@@ -3901,6 +3902,7 @@ static struct ggml_tensor * ggml_rope_impl(
        struct ggml_tensor  * b,
        struct ggml_tensor  * c,
        int                   n_dims,
+        int                   sections[GGML_MROPE_SECTIONS],
        int                   mode,
        int                   n_ctx_orig,
        float                 freq_base,
@@ -3914,15 +3916,19 @@ static struct ggml_tensor * ggml_rope_impl(
 
    GGML_ASSERT(ggml_is_vector(b));
    GGML_ASSERT(b->type == GGML_TYPE_I32);
+
+    bool mrope_used = mode & GGML_ROPE_TYPE_MROPE;
+    if (mrope_used) {
+        GGML_ASSERT(a->ne[2] * 4 == b->ne[0]); // mrope expecting 4 position ids per token
+    } else {
        GGML_ASSERT(a->ne[2] == b->ne[0]);
+    }
 
    if (c) {
        GGML_ASSERT(c->type == GGML_TYPE_F32);
        GGML_ASSERT(c->ne[0] >= n_dims / 2);
    }
 
-    int sections[4] = {0, 0, 0, 0};
-
    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
    int32_t params[15] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
@@ -3932,7 +3938,11 @@ static struct ggml_tensor * ggml_rope_impl(
    memcpy(params +  8, &attn_factor, sizeof(float));
    memcpy(params +  9, &beta_fast,   sizeof(float));
    memcpy(params + 10, &beta_slow,   sizeof(float));
-    memcpy(params + 11, &sections,    sizeof(int)*4);
+    if (mrope_used) {
+        memcpy(params + 11, sections, sizeof(int32_t) * GGML_MROPE_SECTIONS);
+    } else {
+        memset(params + 11, 0, sizeof(int32_t) * GGML_MROPE_SECTIONS);
+    }
    ggml_set_op_params(result, params, sizeof(params));
 
    result->op = GGML_OP_ROPE;
@@ -3950,7 +3960,7 @@ struct ggml_tensor * ggml_rope(
        int                   n_dims,
        int                   mode) {
    return ggml_rope_impl(
-        ctx, a, b, NULL, n_dims, mode, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, false
+        ctx, a, b, NULL, n_dims, NULL, mode, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, false
    );
 }
 
@@ -3960,7 +3970,7 @@ struct ggml_tensor * ggml_rope_multi(
        struct ggml_tensor  * b,
        struct ggml_tensor  * c,
        int                   n_dims,
-        int                   sections[4],
+        int                   sections[GGML_MROPE_SECTIONS],
        int                   mode,
        int                   n_ctx_orig,
        float                 freq_base,
@@ -3969,36 +3979,31 @@ struct ggml_tensor * ggml_rope_multi(
        float                 attn_factor,
        float                 beta_fast,
        float                 beta_slow) {
-    // Multimodal Rotary Position Embedding
-    GGML_ASSERT((mode & 1) == 0 && "mode & 1 == 1 is no longer supported");
-
-    GGML_ASSERT(ggml_is_vector(b));
-    GGML_ASSERT(b->type == GGML_TYPE_I32);
-    GGML_ASSERT(a->ne[2] * 4 == b->ne[0]); // mrope expecting 4 position ids per token
-
-    if (c) {
-        GGML_ASSERT(c->type == GGML_TYPE_F32);
-        GGML_ASSERT(c->ne[0] >= n_dims / 2);
-    }
-
-    struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
-
-    int32_t params[11 + 4] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
-    memcpy(params +  5, &freq_base,   sizeof(float));
-    memcpy(params +  6, &freq_scale,  sizeof(float));
-    memcpy(params +  7, &ext_factor,  sizeof(float));
-    memcpy(params +  8, &attn_factor, sizeof(float));
-    memcpy(params +  9, &beta_fast,   sizeof(float));
-    memcpy(params + 10, &beta_slow,   sizeof(float));
-    memcpy(&params[11], sections, sizeof(int)*4);
-    ggml_set_op_params(result, params, sizeof(params));
-
-    result->op     = GGML_OP_ROPE;
-    result->src[0] = a;
-    result->src[1] = b;
-    result->src[2] = c;
-
-    return result;
+    return ggml_rope_impl(
+        ctx, a, b, c, n_dims, sections, mode, n_ctx_orig, freq_base, freq_scale,
+        ext_factor, attn_factor, beta_fast, beta_slow, false
+    );
+}
+
+struct ggml_tensor * ggml_rope_multi_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        struct ggml_tensor  * c,
+        int                   n_dims,
+        int                   sections[GGML_MROPE_SECTIONS],
+        int                   mode,
+        int                   n_ctx_orig,
+        float                 freq_base,
+        float                 freq_scale,
+        float                 ext_factor,
+        float                 attn_factor,
+        float                 beta_fast,
+        float                 beta_slow) {
+    return ggml_rope_impl(
+        ctx, a, b, c, n_dims, sections, mode, n_ctx_orig, freq_base, freq_scale,
+        ext_factor, attn_factor, beta_fast, beta_slow, true
+    );
 }
 
 struct ggml_tensor * ggml_rope_inplace(
@@ -4008,7 +4013,7 @@ struct ggml_tensor * ggml_rope_inplace(
        int                   n_dims,
        int                   mode) {
    return ggml_rope_impl(
-        ctx, a, b, NULL, n_dims, mode, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, true
+        ctx, a, b, NULL, n_dims, NULL, mode, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, true
    );
 }
 
@@ -4027,7 +4032,7 @@ struct ggml_tensor * ggml_rope_ext(
        float                 beta_fast,
        float                 beta_slow) {
    return ggml_rope_impl(
-        ctx, a, b, c, n_dims, mode, n_ctx_orig, freq_base, freq_scale,
+        ctx, a, b, c, n_dims, NULL, mode, n_ctx_orig, freq_base, freq_scale,
        ext_factor, attn_factor, beta_fast, beta_slow, false
    );
 }
@@ -4047,7 +4052,7 @@ struct ggml_tensor * ggml_rope_ext_inplace(
        float                 beta_fast,
        float                 beta_slow) {
    return ggml_rope_impl(
-        ctx, a, b, c, n_dims, mode, n_ctx_orig, freq_base, freq_scale,
+        ctx, a, b, c, n_dims, NULL, mode, n_ctx_orig, freq_base, freq_scale,
        ext_factor, attn_factor, beta_fast, beta_slow, true
    );
 }
@@ -4066,7 +4071,7 @@ struct ggml_tensor * ggml_rope_custom(
        float                 beta_fast,
        float                 beta_slow) {
    return ggml_rope_impl(
-        ctx, a, b, NULL, n_dims, mode, n_ctx_orig, freq_base, freq_scale,
+        ctx, a, b, NULL, n_dims, NULL, mode, n_ctx_orig, freq_base, freq_scale,
        ext_factor, attn_factor, beta_fast, beta_slow, false
    );
 }
@@ -4085,7 +4090,7 @@ struct ggml_tensor * ggml_rope_custom_inplace(
        float                 beta_fast,
        float                 beta_slow) {
    return ggml_rope_impl(
-        ctx, a, b, NULL, n_dims, mode, n_ctx_orig, freq_base, freq_scale,
+        ctx, a, b, NULL, n_dims, NULL, mode, n_ctx_orig, freq_base, freq_scale,
        ext_factor, attn_factor, beta_fast, beta_slow, true
    );
 }
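With the M-RoPE path above, the position tensor b carries four position ids per token (one per section), which is what the new assert a->ne[2] * 4 == b->ne[0] checks. A hedged illustration of how such a positions buffer could be filled follows; the section-major layout shown is an assumption for the example, not something this hunk prescribes.

#include <cstdint>
#include <vector>

// Regular RoPE: one position id per token.
static std::vector<int32_t> make_pos(int n_tokens) {
    std::vector<int32_t> pos(n_tokens);
    for (int i = 0; i < n_tokens; ++i) pos[i] = i;
    return pos;
}

// M-RoPE: four position ids per token, so the buffer holds 4*n_tokens entries
// (satisfying a->ne[2] * 4 == b->ne[0]). The section-major layout is assumed.
static std::vector<int32_t> make_pos_mrope(int n_tokens) {
    std::vector<int32_t> pos(4 * n_tokens);
    for (int s = 0; s < 4; ++s) {
        for (int i = 0; i < n_tokens; ++i) {
            pos[s * n_tokens + i] = i;
        }
    }
    return pos;
}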
@@ -4283,14 +4288,13 @@ struct ggml_tensor * ggml_conv_1d_dw(
        int                   s0,
        int                   p0,
        int                   d0) {
-    struct ggml_tensor * new_a = ggml_reshape_4d(ctx, a, a->ne[0], 1, a->ne[1], a->ne[2]);
    struct ggml_tensor * new_b = ggml_reshape_4d(ctx, b, b->ne[0], 1, b->ne[1], b->ne[2]);
 
-    struct ggml_tensor * im2col = ggml_im2col(ctx, new_a, new_b, s0, 0, p0, 0, d0, 0, false, GGML_TYPE_F16);
+    struct ggml_tensor * im2col = ggml_im2col(ctx, a, new_b, s0, 0, p0, 0, d0, 0, false, GGML_TYPE_F16);
 
    struct ggml_tensor * result = ggml_mul_mat(ctx, im2col, a);
 
-    result = ggml_reshape_3d(ctx, result, b->ne[0], b->ne[1], 1);
+    result = ggml_reshape_3d(ctx, result, result->ne[0], result->ne[2], 1);
 
    return result;
 }
@@ -5618,6 +5622,28 @@ struct ggml_tensor * ggml_opt_step_adamw(
    return result;
 }
 
+// opt_step_sgd
+
+struct ggml_tensor * ggml_opt_step_sgd(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * grad,
+        struct ggml_tensor  * params) {
+    GGML_ASSERT(a->flags & GGML_TENSOR_FLAG_PARAM);
+    GGML_ASSERT(ggml_are_same_shape(a, grad));
+    GGML_ASSERT(params->type == GGML_TYPE_F32);
+    GGML_ASSERT(ggml_nelements(params) == 2);
+
+    struct ggml_tensor * result = ggml_view_tensor(ctx, a);
+
+    result->op     = GGML_OP_OPT_STEP_SGD;
+    result->src[0] = a;
+    result->src[1] = grad;
+    result->src[2] = params;
+
+    return result;
+}
+
 ////////////////////////////////////////////////////////////////////////////////
 
 struct ggml_hash_set ggml_hash_set_new(size_t size) {
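For context, a hedged sketch of how the new ggml_opt_step_sgd op might be wired into a graph. The helper name and all tensor names are illustrative; it assumes a context with allocated tensor data (no_alloc disabled), a graph gf, an f32 weight tensor w already flagged as a parameter, and a matching gradient tensor g.

#include "ggml.h"

// Hypothetical helper: adds one SGD update step for parameter `w` with
// gradient `g` to the graph `gf`. All names are assumptions for the example.
static void add_sgd_step(struct ggml_context * ctx, struct ggml_cgraph * gf,
                         struct ggml_tensor * w, struct ggml_tensor * g,
                         float alpha, float wd) {
    // two-element f32 tensor holding { learning rate, weight decay },
    // matching the layout asserted by ggml_opt_step_sgd above
    struct ggml_tensor * pars = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 2);
    ((float *) pars->data)[0] = alpha;
    ((float *) pars->data)[1] = wd;

    // the result is a view of `w`; evaluating it performs the in-place update
    struct ggml_tensor * step = ggml_opt_step_sgd(ctx, w, g, pars);
    ggml_build_forward_expand(gf, step);
}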
@@ -873,6 +873,29 @@ extern "C" {
                  size_t   n_token_capacity,
                  size_t * n_token_count_out);
 
+#define LLAMA_STATE_SEQ_FLAGS_SWA_ONLY 1
+
+    typedef uint32_t llama_state_seq_flags;
+
+    LLAMA_API size_t llama_state_seq_get_size_ext(
+            struct llama_context * ctx,
+                    llama_seq_id   seq_id,
+           llama_state_seq_flags   flags);
+
+    LLAMA_API size_t llama_state_seq_get_data_ext(
+            struct llama_context * ctx,
+                         uint8_t * dst,
+                          size_t   size,
+                    llama_seq_id   seq_id,
+           llama_state_seq_flags   flags);
+
+    LLAMA_API size_t llama_state_seq_set_data_ext(
+            struct llama_context * ctx,
+                   const uint8_t * src,
+                          size_t   size,
+                    llama_seq_id   dest_seq_id,
+           llama_state_seq_flags   flags);
+
    //
    // Decoding
    //
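A hedged usage sketch of the new calls, mirroring what the server change further down does for SWA checkpoints: snapshot only the SWA portion of a sequence's KV state and restore it later. The helper names are illustrative and error handling is omitted; ctx and seq are assumed to exist.

#include <cstdint>
#include <vector>
#include "llama.h"

// Take a SWA-only snapshot of the given sequence.
static std::vector<uint8_t> swa_checkpoint_take(llama_context * ctx, llama_seq_id seq) {
    std::vector<uint8_t> buf(llama_state_seq_get_size_ext(ctx, seq, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY));
    llama_state_seq_get_data_ext(ctx, buf.data(), buf.size(), seq, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);
    return buf;
}

// Roll the SWA cache of the sequence back to the checkpointed state.
static void swa_checkpoint_restore(llama_context * ctx, llama_seq_id seq, const std::vector<uint8_t> & buf) {
    llama_state_seq_set_data_ext(ctx, buf.data(), buf.size(), seq, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);
}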
@@ -1440,6 +1463,8 @@ extern "C" {
 
        ggml_opt_get_optimizer_params get_opt_pars; // callback for calculating optimizer parameters
        void * get_opt_pars_ud;                     // userdata for calculating optimizer parameters
+
+        enum ggml_opt_optimizer_type optimizer_type;
    };
 
    LLAMA_API void llama_opt_init(struct llama_context * lctx, struct llama_model * model, struct llama_opt_params lopt_params);
 
@@ -476,7 +476,7 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
 
 llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch, bool sequential) {
    if (sequential && has_cpl) {
-        LLAMA_LOG_ERROR("%s: sequential split is not supported when there are coupled sequences in the input batch\n", __func__);
+        LLAMA_LOG_ERROR("%s: sequential split is not supported when there are coupled sequences in the input batch (you may need to use the -kvu flag)\n", __func__);
 
        return {};
    }
|
@ -1657,30 +1657,30 @@ size_t llama_context::state_set_data(const uint8_t * src, size_t size) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t llama_context::state_seq_get_size(llama_seq_id seq_id) {
|
size_t llama_context::state_seq_get_size(llama_seq_id seq_id, llama_state_seq_flags flags) {
|
||||||
llama_io_write_dummy io;
|
llama_io_write_dummy io;
|
||||||
try {
|
try {
|
||||||
return state_seq_write_data(io, seq_id);
|
return state_seq_write_data(io, seq_id, flags);
|
||||||
} catch (const std::exception & err) {
|
} catch (const std::exception & err) {
|
||||||
LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what());
|
LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what());
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t llama_context::state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size) {
|
size_t llama_context::state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size, llama_state_seq_flags flags) {
|
||||||
llama_io_write_buffer io(dst, size);
|
llama_io_write_buffer io(dst, size);
|
||||||
try {
|
try {
|
||||||
return state_seq_write_data(io, seq_id);
|
return state_seq_write_data(io, seq_id, flags);
|
||||||
} catch (const std::exception & err) {
|
} catch (const std::exception & err) {
|
||||||
LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what());
|
LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what());
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t llama_context::state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size) {
|
size_t llama_context::state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size, llama_state_seq_flags flags) {
|
||||||
llama_io_read_buffer io(src, size);
|
llama_io_read_buffer io(src, size);
|
||||||
try {
|
try {
|
||||||
return state_seq_read_data(io, seq_id);
|
return state_seq_read_data(io, seq_id, flags);
|
||||||
} catch (const std::exception & err) {
|
} catch (const std::exception & err) {
|
||||||
LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what());
|
LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what());
|
||||||
return 0;
|
return 0;
|
||||||
|
@ -1778,7 +1778,7 @@ size_t llama_context::state_seq_load_file(llama_seq_id seq_id, const char * file
|
||||||
{
|
{
|
||||||
const size_t state_size = file.size() - file.tell();
|
const size_t state_size = file.size() - file.tell();
|
||||||
llama_io_read_file io(&file);
|
llama_io_read_file io(&file);
|
||||||
const size_t nread = state_seq_read_data(io, seq_id);
|
const size_t nread = state_seq_read_data(io, seq_id, 0);
|
||||||
if (!nread) {
|
if (!nread) {
|
||||||
LLAMA_LOG_ERROR("%s: failed to restore sequence state\n", __func__);
|
LLAMA_LOG_ERROR("%s: failed to restore sequence state\n", __func__);
|
||||||
return 0;
|
return 0;
|
||||||
|
@ -1802,7 +1802,7 @@ size_t llama_context::state_seq_save_file(llama_seq_id seq_id, const char * file
|
||||||
|
|
||||||
// save the context state using stream saving
|
// save the context state using stream saving
|
||||||
llama_io_write_file io(&file);
|
llama_io_write_file io(&file);
|
||||||
state_seq_write_data(io, seq_id);
|
state_seq_write_data(io, seq_id, 0);
|
||||||
|
|
||||||
const size_t res = file.tell();
|
const size_t res = file.tell();
|
||||||
GGML_ASSERT(res == sizeof(uint32_t) * 3 + sizeof(llama_token) * n_token_count + io.n_bytes());
|
GGML_ASSERT(res == sizeof(uint32_t) * 3 + sizeof(llama_token) * n_token_count + io.n_bytes());
|
||||||
|
@ -1971,21 +1971,21 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
|
||||||
return io.n_bytes();
|
return io.n_bytes();
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id) {
|
size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
|
||||||
GGML_UNUSED(seq_id);
|
GGML_UNUSED(seq_id);
|
||||||
|
|
||||||
if (memory) {
|
if (memory) {
|
||||||
memory->state_write(io, seq_id);
|
memory->state_write(io, seq_id, flags);
|
||||||
}
|
}
|
||||||
|
|
||||||
return io.n_bytes();
|
return io.n_bytes();
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id) {
|
size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
|
||||||
GGML_UNUSED(seq_id);
|
GGML_UNUSED(seq_id);
|
||||||
|
|
||||||
if (memory) {
|
if (memory) {
|
||||||
memory->state_read(io, seq_id);
|
memory->state_read(io, seq_id, flags);
|
||||||
}
|
}
|
||||||
|
|
||||||
return io.n_bytes();
|
return io.n_bytes();
|
||||||
|
@ -2048,7 +2048,7 @@ void llama_context::opt_init(struct llama_model * model, struct llama_opt_params
|
||||||
opt_params.opt_period = n_batch / n_ubatch;
|
opt_params.opt_period = n_batch / n_ubatch;
|
||||||
opt_params.get_opt_pars = lopt_params.get_opt_pars;
|
opt_params.get_opt_pars = lopt_params.get_opt_pars;
|
||||||
opt_params.get_opt_pars_ud = lopt_params.get_opt_pars_ud;
|
opt_params.get_opt_pars_ud = lopt_params.get_opt_pars_ud;
|
||||||
|
opt_params.optimizer = lopt_params.optimizer_type;
|
||||||
opt_ctx = ggml_opt_init(opt_params);
|
opt_ctx = ggml_opt_init(opt_params);
|
||||||
|
|
||||||
llama_opt_param_filter param_filter = lopt_params.param_filter;
|
llama_opt_param_filter param_filter = lopt_params.param_filter;
|
||||||
|
@ -2801,19 +2801,31 @@ bool llama_state_save_file(llama_context * ctx, const char * path_session, const
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t llama_state_seq_get_size(llama_context * ctx, llama_seq_id seq_id) {
|
size_t llama_state_seq_get_size(llama_context * ctx, llama_seq_id seq_id) {
|
||||||
return ctx->state_seq_get_size(seq_id);
|
return llama_state_seq_get_size_ext(ctx, seq_id, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t llama_state_seq_get_data(llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id) {
|
size_t llama_state_seq_get_data(llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id) {
|
||||||
ctx->synchronize();
|
return llama_state_seq_get_data_ext(ctx, dst, size, seq_id, 0);
|
||||||
|
|
||||||
return ctx->state_seq_get_data(seq_id, dst, size);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t llama_state_seq_set_data(llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id) {
|
size_t llama_state_seq_set_data(llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id) {
|
||||||
|
return llama_state_seq_set_data_ext(ctx, src, size, seq_id, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t llama_state_seq_get_size_ext(llama_context * ctx, llama_seq_id seq_id, llama_state_seq_flags flags) {
|
||||||
|
return ctx->state_seq_get_size(seq_id, flags);
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t llama_state_seq_get_data_ext(llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id, llama_state_seq_flags flags) {
|
||||||
ctx->synchronize();
|
ctx->synchronize();
|
||||||
|
|
||||||
return ctx->state_seq_set_data(seq_id, src, size);
|
return ctx->state_seq_get_data(seq_id, dst, size, flags);
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t llama_state_seq_set_data_ext(llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id, llama_state_seq_flags flags) {
|
||||||
|
ctx->synchronize();
|
||||||
|
|
||||||
|
return ctx->state_seq_set_data(seq_id, src, size, flags);
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t llama_state_seq_save_file(llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) {
|
size_t llama_state_seq_save_file(llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) {
|
||||||
|
|
|
@ -111,9 +111,9 @@ struct llama_context {
|
||||||
size_t state_get_data( uint8_t * dst, size_t size);
|
size_t state_get_data( uint8_t * dst, size_t size);
|
||||||
size_t state_set_data(const uint8_t * src, size_t size);
|
size_t state_set_data(const uint8_t * src, size_t size);
|
||||||
|
|
||||||
size_t state_seq_get_size(llama_seq_id seq_id);
|
size_t state_seq_get_size(llama_seq_id seq_id, llama_state_seq_flags flags);
|
||||||
size_t state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size);
|
size_t state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size, llama_state_seq_flags flags);
|
||||||
size_t state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size);
|
size_t state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size, llama_state_seq_flags flags);
|
||||||
|
|
||||||
bool state_load_file(
|
bool state_load_file(
|
||||||
const char * filepath,
|
const char * filepath,
|
||||||
|
@ -152,6 +152,7 @@ struct llama_context {
|
||||||
|
|
||||||
void opt_init(struct llama_model * model, struct llama_opt_params lopt_params);
|
void opt_init(struct llama_model * model, struct llama_opt_params lopt_params);
|
||||||
|
|
||||||
|
// TODO: more flexible combinations of logical/physical batch size and context size
|
||||||
void opt_epoch(
|
void opt_epoch(
|
||||||
ggml_opt_dataset_t dataset,
|
ggml_opt_dataset_t dataset,
|
||||||
ggml_opt_result_t result_train,
|
ggml_opt_result_t result_train,
|
||||||
|
@ -212,8 +213,8 @@ private:
|
||||||
size_t state_write_data(llama_io_write_i & io);
|
size_t state_write_data(llama_io_write_i & io);
|
||||||
size_t state_read_data (llama_io_read_i & io);
|
size_t state_read_data (llama_io_read_i & io);
|
||||||
|
|
||||||
size_t state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id);
|
size_t state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags);
|
||||||
size_t state_seq_read_data (llama_io_read_i & io, llama_seq_id seq_id);
|
size_t state_seq_read_data (llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags);
|
||||||
|
|
||||||
//
|
//
|
||||||
// members
|
// members
|
||||||
|
|
|
@ -198,14 +198,20 @@ bool llama_kv_cache_unified_iswa::get_can_shift() const {
|
||||||
return kv_base->get_size() == kv_swa->get_size();
|
return kv_base->get_size() == kv_swa->get_size();
|
||||||
}
|
}
|
||||||
|
|
||||||
void llama_kv_cache_unified_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
|
void llama_kv_cache_unified_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
|
||||||
kv_base->state_write(io, seq_id);
|
if ((flags & LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) == 0) {
|
||||||
kv_swa ->state_write(io, seq_id);
|
kv_base->state_write(io, seq_id, flags);
|
||||||
}
|
}
|
||||||
|
|
||||||
void llama_kv_cache_unified_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
|
kv_swa->state_write(io, seq_id, flags);
|
||||||
kv_base->state_read(io, seq_id);
|
}
|
||||||
kv_swa ->state_read(io, seq_id);
|
|
||||||
|
void llama_kv_cache_unified_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
|
||||||
|
if ((flags & LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) == 0) {
|
||||||
|
kv_base->state_read(io, seq_id, flags);
|
||||||
|
}
|
||||||
|
|
||||||
|
kv_swa->state_read(io, seq_id, flags);
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_base() const {
|
llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_base() const {
|
||||||
|
|
|
@ -56,8 +56,8 @@ public:
|
||||||
|
|
||||||
// state write/load
|
// state write/load
|
||||||
|
|
||||||
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
|
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
|
||||||
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
|
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;
|
||||||
|
|
||||||
//
|
//
|
||||||
// llama_kv_cache_unified_iswa specific API
|
// llama_kv_cache_unified_iswa specific API
|
||||||
|
|
|
@ -1828,7 +1828,9 @@ bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
|
void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
|
||||||
|
GGML_UNUSED(flags);
|
||||||
|
|
||||||
io.write(&n_stream, sizeof(n_stream));
|
io.write(&n_stream, sizeof(n_stream));
|
||||||
|
|
||||||
for (uint32_t s = 0; s < n_stream; ++s) {
|
for (uint32_t s = 0; s < n_stream; ++s) {
|
||||||
|
@ -1879,7 +1881,9 @@ void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
|
void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
|
||||||
|
GGML_UNUSED(flags);
|
||||||
|
|
||||||
GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()));
|
GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()));
|
||||||
|
|
||||||
uint32_t n_stream_cur;
|
uint32_t n_stream_cur;
|
||||||
|
|
|
@ -136,8 +136,8 @@ public:
|
||||||
|
|
||||||
// state write/load
|
// state write/load
|
||||||
|
|
||||||
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
|
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
|
||||||
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
|
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;
|
||||||
|
|
||||||
//
|
//
|
||||||
// llama_kv_cache_unified specific API
|
// llama_kv_cache_unified specific API
|
||||||
|
|
|
@@ -165,12 +165,16 @@ llama_pos llama_memory_hybrid::seq_pos_max(llama_seq_id seq_id) const {
    return std::min(mem_attn->seq_pos_max(seq_id), mem_recr->seq_pos_max(seq_id));
 }
 
-void llama_memory_hybrid::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
+void llama_memory_hybrid::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
+    GGML_UNUSED(flags);
+
    mem_attn->state_write(io, seq_id);
    mem_recr->state_write(io, seq_id);
 }
 
-void llama_memory_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
+void llama_memory_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
+    GGML_UNUSED(flags);
+
    mem_attn->state_read(io, seq_id);
    mem_recr->state_read(io, seq_id);
 }
@@ -74,8 +74,8 @@ public:
 
    // state write/load
 
-    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
-    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1) override;
+    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
+    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;
 
    //
    // llama_memory_hybrid specific API
@@ -680,7 +680,9 @@ size_t llama_memory_recurrent::size_s_bytes() const {
    return size_s_bytes;
 }
 
-void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
+void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
+    GGML_UNUSED(flags);
+
    std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
    uint32_t cell_count = 0;
 
@@ -718,7 +720,9 @@ void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq
    state_write_data(io, cell_ranges);
 }
 
-void llama_memory_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
+void llama_memory_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
+    GGML_UNUSED(flags);
+
    uint32_t cell_count;
    io.read_to(&cell_count, sizeof(cell_count));
 
@@ -63,8 +63,8 @@ public:
 
    // state write/load
 
-    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
-    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1) override;
+    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
+    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;
 
    uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot())
    uint32_t size = 0; // total number of cells, shared across all sequences
@@ -104,8 +104,8 @@ struct llama_memory_i {
    // state write/read
    //
 
-    virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const = 0;
-    virtual void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1) = 0;
+    virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const = 0;
+    virtual void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) = 0;
 };
 
 using llama_memory_ptr = std::unique_ptr<llama_memory_i>;
Binary file not shown.
@@ -692,6 +692,13 @@ struct completion_token_output {
    }
 };
 
+struct swa_checkpoint {
+    llama_pos pos_min;
+    llama_pos pos_max;
+
+    std::vector<uint8_t> data;
+};
+
 struct server_task_result_cmpl_final : server_task_result {
    int index = 0;
 
@@ -1336,6 +1343,8 @@ struct server_slot {
 
    std::vector<completion_token_output> generated_token_probs;
 
+    std::vector<swa_checkpoint> swa_checkpoints;
+
    bool has_next_token = true;
    bool has_new_line   = false;
    bool truncated      = false;
@@ -2015,6 +2024,10 @@ struct server_context {
            params_dft.cache_type_k = params_base.speculative.cache_type_k;
            params_dft.cache_type_v = params_base.speculative.cache_type_v;
 
+            params_dft.cpuparams.n_threads       = params_base.speculative.cpuparams.n_threads;
+            params_dft.cpuparams_batch.n_threads = params_base.speculative.cpuparams_batch.n_threads;
+            params_dft.tensor_buft_overrides     = params_base.speculative.tensor_buft_overrides;
+
            llama_init_dft = common_init_from_params(params_dft);
 
            model_dft = llama_init_dft.model.get();
@@ -3289,6 +3302,8 @@ struct server_context {
                    slot.n_past = 0;
                }
 
+                const auto n_swa = llama_model_n_swa(model);
+
                if (slot.n_past > 0 && slot.n_past < (int) slot.cache_tokens.size()) {
                    const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
                    if (pos_min == -1) {
@@ -3296,12 +3311,58 @@ struct server_context {
                        GGML_ABORT("pos_min == -1, but n_past > 0 - should not happen: https://github.com/ggml-org/llama.cpp/pull/13833#discussion_r2116181237");
                    }
 
-                    const auto n_swa = llama_model_n_swa(model);
-                    if (pos_min > std::max(0, slot.n_past - n_swa)) {
+                    const auto pos_min_thold = std::max(0, slot.n_past - n_swa);
+
+                    if (pos_min > pos_min_thold) {
                        SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min, n_swa);
+
+                        // search for a SWA checkpoint
+                        const auto it = std::find_if(
+                                slot.swa_checkpoints.rbegin(),
+                                slot.swa_checkpoints.rend(),
+                                [&](const auto & cur) {
+                                    return cur.pos_min <= pos_min_thold;
+                                }
+                                );
+
+                        bool do_reset = it == slot.swa_checkpoints.rend();
+
+                        if (!do_reset) {
+                            // restore the checkpoint
+                            const size_t swa_size = it->data.size();
+                            const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), swa_size, slot.id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);
+
+                            if (n != swa_size) {
+                                SLT_ERR(slot, "failed to restore SWA checkpoint, pos_min = %d, pos_max = %d, size = %.3f MiB\n", it->pos_min, it->pos_max, (float) swa_size / 1024 / 1024);
+                                do_reset = true;
+                            } else {
+                                slot.n_past = std::min(slot.n_past, it->pos_max);
+
+                                SLT_WRN(slot, "SWA checkpoint restore, pos_min = %d, pos_max = %d, size = %.3f MiB\n", it->pos_min, it->pos_max, (float) swa_size / 1024 / 1024);
+                            }
+                        }
+
+                        if (do_reset) {
                            SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA, see %s)\n",
                                    "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
 
                            slot.n_past = 0;
+                            slot.swa_checkpoints.clear();
+                        }
+                    }
+                }
+
+                if (n_swa > 0) {
+                    const auto pos_min_thold = std::max(0, slot.n_past - n_swa);
+
+                    // erase any checkpoints with pos_min > pos_min_thold
+                    for (int i = (int) slot.swa_checkpoints.size() - 1; i >= 0; i--) {
+                        const auto & cur = slot.swa_checkpoints[i];
+                        if (cur.pos_min > pos_min_thold) {
+                            slot.swa_checkpoints.erase(slot.swa_checkpoints.begin() + i);
+
+                            SLT_WRN(slot, "SWA checkpoint erase, pos_min = %d, pos_max = %d, size = %.3f MiB\n", cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
+                        }
                    }
                }
            }
@@ -3515,6 +3576,39 @@ struct server_context {
 
                    // prompt evaluated for next-token prediction
                    slot.state = SLOT_STATE_GENERATING;
+
+                    // make a checkpoint with the SWA memory
+                    // checkpoints are needed only if we are not using "--swa-full"
+                    if (llama_model_n_swa(model) > 0 && !params_base.swa_full && params_base.n_swa_checkpoints > 0) {
+                        if (slot.swa_checkpoints.size() >= (size_t) params_base.n_swa_checkpoints) {
+                            {
+                                const auto & cur = slot.swa_checkpoints.back();
+
+                                SLT_WRN(slot, "SWA checkpoint erase, pos_min = %d, pos_max = %d, size = %.3f MiB\n",
+                                        cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
+                            }
+
+                            slot.swa_checkpoints.erase(slot.swa_checkpoints.begin());
+                        }
+
+                        const size_t swa_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);
+
+                        auto & cur = slot.swa_checkpoints.emplace_back(swa_checkpoint{
+                            /*.pos_min = */ llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id),
+                            /*.pos_max = */ llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id),
+                            /*.data    = */ std::vector<uint8_t>(swa_size),
+                        });
+
+                        llama_state_seq_get_data_ext(ctx, cur.data.data(), swa_size, slot.id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);
+
+                        float size_total = 0.0f;
+                        for (const auto & checkpoint : slot.swa_checkpoints) {
+                            size_total += (float) checkpoint.data.size() / 1024 / 1024;
+                        }
+
+                        SLT_WRN(slot, "SWA checkpoint create, pos_min = %d, pos_max = %d, size = %.3f MiB, total = %d/%d (%.3f MiB)\n",
+                                cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024, (int) slot.swa_checkpoints.size(), params_base.n_swa_checkpoints, size_total);
+                    }
                } else if (slot.state != SLOT_STATE_GENERATING) {
                    continue; // continue loop of slots
                }
@@ -130,7 +130,12 @@ export function filterThoughtFromMsgs(messages: APIMessage[]) {
      role: msg.role,
      content:
        msg.role === 'assistant'
-          ? contentStr.split('</think>').at(-1)!.trim()
+          ? contentStr
+              .split(
+                /<\/think>|<\|start\|>assistant<\|channel\|>final<\|message\|>/
+              )
+              .at(-1)!
+              .trim()
          : contentStr,
    } as APIMessage;
  });