Merge branch 'upstream' into concedo_experimental

# Conflicts:
#	.devops/nix/package.nix
#	.github/workflows/build.yml
#	.github/workflows/server.yml
#	CMakeLists.txt
#	Makefile
#	README.md
#	ggml-cuda.cu
#	tests/test-backend-ops.cpp
Concedo · 2024-05-19 17:55:20 +08:00 · commit d5d5dda02b
34 changed files with 9731 additions and 4163 deletions

.github/labeler.yml (vendored, new file, +73)

@@ -0,0 +1,73 @@
# https://github.com/actions/labeler
SYCL:
  - changed-files:
    - any-glob-to-any-file:
      - ggml-sycl.h
      - ggml-sycl.cpp
      - README-sycl.md
Nvidia GPU:
  - changed-files:
    - any-glob-to-any-file:
      - ggml-cuda/**
Vulkan:
  - changed-files:
    - any-glob-to-any-file:
      - ggml_vk_generate_shaders.py
      - ggml-vulkan*
documentation:
  - changed-files:
    - any-glob-to-any-file:
      - docs/**
      - media/**
testing:
  - changed-files:
    - any-glob-to-any-file:
      - tests/**
build:
  - changed-files:
    - any-glob-to-any-file:
      - cmake/**
      - CMakeLists.txt
      - CMakePresets.json
      - codecov.yml
examples:
  - changed-files:
    - any-glob-to-any-file: examples/**
devops:
  - changed-files:
    - any-glob-to-any-file:
      - .devops/**
      - .github/**
      - ci/**
python:
  - changed-files:
    - any-glob-to-any-file:
      - "**/*.py"
      - requirements/**
      - gguf-py/**
      - .flake8
script:
  - changed-files:
    - any-glob-to-any-file:
      - scripts/**
android:
  - changed-files:
    - any-glob-to-any-file:
      - examples/llama.android/**
server:
  - changed-files:
    - any-glob-to-any-file:
      - examples/server/**
ggml:
  - changed-files:
    - any-glob-to-any-file:
      - ggml-*.c
      - ggml-*.h
      - ggml-cuda/**
nix:
  - changed-files:
    - any-glob-to-any-file:
      - "**/*.nix"
      - .github/workflows/nix-*.yml
      - .devops/nix/nixpkgs-instances.nix

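Note: with actions/labeler v5 (wired up in the workflow below), each `any-glob-to-any-file` entry applies its label when any file changed in the pull request matches any of the listed globs.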
.github/workflows/labeler.yml (vendored, new file, +12)

@@ -0,0 +1,12 @@
name: "Pull Request Labeler"
on:
  - pull_request_target

jobs:
  labeler:
    permissions:
      contents: read
      pull-requests: write
    runs-on: ubuntu-latest
    steps:
      - uses: actions/labeler@v5


@@ -573,6 +573,10 @@ class Model:
         vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size())

+        tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)]
+        scores: list[float] = [-10000.0] * vocab_size
+        toktypes: list[int] = [SentencePieceTokenTypes.UNKNOWN] * vocab_size
+
         for token_id in range(tokenizer.vocab_size()):
             piece = tokenizer.IdToPiece(token_id)
             text = piece.encode("utf-8")
@@ -588,21 +592,23 @@ class Model:
             elif tokenizer.IsByte(token_id):
                 toktype = SentencePieceTokenTypes.BYTE

-            tokens.append(text)
-            scores.append(score)
-            toktypes.append(toktype)
+            tokens[token_id] = text
+            scores[token_id] = score
+            toktypes[token_id] = toktype

         added_tokens_file = self.dir_model / 'added_tokens.json'
         if added_tokens_file.is_file():
             with open(added_tokens_file, "r", encoding="utf-8") as f:
                 added_tokens_json = json.load(f)
                 for key in added_tokens_json:
-                    key = key.encode("utf-8")
-                    if key not in tokens:
-                        tokens.append(key)
-                        scores.append(-1000.0)
-                        toktypes.append(SentencePieceTokenTypes.USER_DEFINED)
+                    token_id = added_tokens_json[key]
+                    if (token_id >= vocab_size):
+                        logger.warning(f'ignore token {token_id}: id is out of range, max={vocab_size - 1}')
+                        continue
+
+                    tokens[token_id] = key.encode("utf-8")
+                    scores[token_id] = -1000.0
+                    toktypes[token_id] = SentencePieceTokenTypes.USER_DEFINED

         if vocab_size > len(tokens):
             pad_count = vocab_size - len(tokens)
@@ -612,8 +618,6 @@ class Model:
                 scores.append(-1000.0)
                 toktypes.append(SentencePieceTokenTypes.UNUSED)

-        assert len(tokens) == vocab_size
-
         self.gguf_writer.add_tokenizer_model("llama")
         self.gguf_writer.add_tokenizer_pre("default")
         self.gguf_writer.add_token_list(tokens)
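The pattern above (preallocate dense, id-indexed tables, then overwrite by token id and skip out-of-range entries with a warning) is worth spelling out; a minimal C++ sketch of the same discipline, with hypothetical token data rather than the converter's real inputs:

#include <cstdio>
#include <string>
#include <utility>
#include <vector>

int main() {
    const int vocab_size = 8;
    // Dense, id-indexed table prefilled with pad placeholders, mirroring tokens[token_id] = ... above.
    std::vector<std::string> tokens(vocab_size);
    for (int i = 0; i < vocab_size; ++i) {
        tokens[i] = "[PAD" + std::to_string(i) + "]";
    }
    // Hypothetical added tokens; the second id is deliberately out of range.
    const std::pair<int, std::string> added[] = {{5, "<tool_call>"}, {42, "<oops>"}};
    for (const auto & kv : added) {
        if (kv.first >= vocab_size) { // same out-of-range guard as the converter
            std::fprintf(stderr, "ignore token %d: id is out of range, max=%d\n", kv.first, vocab_size - 1);
            continue;
        }
        tokens[kv.first] = kv.second; // overwrite the placeholder in place
    }
    return 0;
}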


@@ -12,15 +12,17 @@ cmake_minimum_required(VERSION 3.22.1)
 # build script scope).
 project("llama-android")

-include(FetchContent)
-FetchContent_Declare(
-        llama
-        GIT_REPOSITORY https://github.com/ggerganov/llama.cpp
-        GIT_TAG        master
-)
+#include(FetchContent)
+#FetchContent_Declare(
+#        llama
+#        GIT_REPOSITORY https://github.com/ggerganov/llama.cpp
+#        GIT_TAG        ci-android
+#)
+#
+## Also provides "common"
+#FetchContent_MakeAvailable(llama)

-# Also provides "common"
-FetchContent_MakeAvailable(llama)
+add_subdirectory(../../../../../../ please-work)

 # Creates and names a library, sets it as either STATIC
 # or SHARED, and provides the relative paths to its source code.


@@ -1426,7 +1426,7 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params
         // Use all tasks
         tasks.resize(n_task);
         printf("%s: reading tasks", __func__);
-        int n_dot = n_task/100;
+        int n_dot = std::max((int) n_task/100, 1);
         int i = 0;
         for (auto& task : tasks) {
             ++i;
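Clamping n_dot to at least 1 matters when fewer than 100 tasks are loaded: n_task/100 would otherwise truncate to 0, and the progress-dot logic that presumably takes i modulo n_dot would then divide by zero.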
@@ -1676,7 +1676,7 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params
     llama_batch_free(batch);

-    if (n_done < 100) return;
+    if (n_done < 100 && (params.multiple_choice_tasks != 0 && params.multiple_choice_tasks < (size_t)n_task)) return;

     float p = 1.f*n_correct/n_done;
     float sigma = sqrt(p*(1-p)/(n_done-1));


@@ -56,6 +56,10 @@ static bool rpc_server_params_parse(int argc, char ** argv, rpc_server_params &
         } else if (arg == "-h" || arg == "--help") {
             print_usage(argc, argv, params);
             exit(0);
+        } else {
+            fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
+            print_usage(argc, argv, params);
+            exit(0);
         }
     }
     return true;


@@ -18,8 +18,8 @@ The project is under active development, and we are [looking for feedback and co

 **Command line options:**

 - `-v`, `--verbose`: Enable verbose server output. When using the `/completion` endpoint, this includes the tokenized prompt, the full request and the full response.
-- `-t N`, `--threads N`: Set the number of threads to use during generation. Not used if model layers are offloaded to GPU. The server is using batching. This parameter is used only if one token is to be processed on CPU backend.
-- `-tb N, --threads-batch N`: Set the number of threads to use during batch and prompt processing. If not specified, the number of threads will be set to the number of threads used for generation. Not used if model layers are offloaded to GPU.
+- `-t N`, `--threads N`: Set the number of threads to use by CPU layers during generation. Not used by model layers that are offloaded to GPU. This option has no effect when using the maximum number of GPU layers. Default: `std::thread::hardware_concurrency()` (number of CPU cores).
+- `-tb N, --threads-batch N`: Set the number of threads to use by CPU layers during batch and prompt processing (>= 32 tokens). This option has no effect if a GPU is available. Default: `--threads`.
 - `--threads-http N`: Number of threads in the http server pool to process requests. Default: `max(std::thread::hardware_concurrency() - 1, --parallel N + 2)`
 - `-m FNAME`, `--model FNAME`: Specify the path to the LLaMA model file (e.g., `models/7B/ggml-model.gguf`).
 - `-mu MODEL_URL --model-url MODEL_URL`: Specify a remote http url to download the file. Default: unused
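For example, a hypothetical invocation combining these options, `./server -m models/7B/ggml-model.gguf -t 8 -tb 16`, uses 8 CPU threads for generation and 16 for batch/prompt processing; both settings are ignored once all model layers are offloaded to the GPU.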


@@ -45,19 +45,59 @@ static bool g_mul_mat_q = false;
 #include <mutex>
 #include <stdint.h>
 #include <stdio.h>
+#include <stdarg.h>
+#include <stdlib.h>
 #include <string>
 #include <vector>

 static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");

+static void ggml_cuda_default_log_callback(enum ggml_log_level level, const char * msg, void * user_data) {
+    GGML_UNUSED(level);
+    GGML_UNUSED(user_data);
+    fprintf(stderr, "%s", msg);
+}
+
+ggml_log_callback ggml_cuda_log_callback = ggml_cuda_default_log_callback;
+void * ggml_cuda_log_user_data = NULL;
+
+GGML_API void ggml_backend_cuda_log_set_callback(ggml_log_callback log_callback, void * user_data) {
+    ggml_cuda_log_callback  = log_callback;
+    ggml_cuda_log_user_data = user_data;
+}
+
+#define GGML_CUDA_LOG_INFO(...)  ggml_cuda_log(GGML_LOG_LEVEL_INFO,  __VA_ARGS__)
+#define GGML_CUDA_LOG_WARN(...)  ggml_cuda_log(GGML_LOG_LEVEL_WARN,  __VA_ARGS__)
+#define GGML_CUDA_LOG_ERROR(...) ggml_cuda_log(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
+
+GGML_ATTRIBUTE_FORMAT(2, 3)
+static void ggml_cuda_log(enum ggml_log_level level, const char * format, ...) {
+    if (ggml_cuda_log_callback != NULL) {
+        va_list args;
+        va_start(args, format);
+        char buffer[128];
+        int len = vsnprintf(buffer, 128, format, args);
+        if (len < 128) {
+            ggml_cuda_log_callback(level, buffer, ggml_cuda_log_user_data);
+        } else {
+            std::vector<char> buffer2(len + 1); // vsnprintf adds a null terminator
+            va_end(args);
+            va_start(args, format);
+            vsnprintf(&buffer2[0], buffer2.size(), format, args);
+            ggml_cuda_log_callback(level, buffer2.data(), ggml_cuda_log_user_data);
+        }
+        va_end(args);
+    }
+}
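The helper above re-issues va_start for the second formatting pass; a minimal standalone sketch of the same try-stack-then-heap vsnprintf pattern, using va_copy instead (assumed names, not part of the diff):

#include <cstdarg>
#include <cstdio>
#include <vector>

static void log_fmt(const char * fmt, ...) {
    va_list args, args2;
    va_start(args, fmt);
    va_copy(args2, args);                                    // keep a second cursor for pass 2
    char buf[128];
    const int len = vsnprintf(buf, sizeof(buf), fmt, args);  // pass 1: try the stack buffer
    if (len >= (int) sizeof(buf)) {
        std::vector<char> big(len + 1);                      // vsnprintf reported the needed size
        vsnprintf(big.data(), big.size(), fmt, args2);       // pass 2: format into the heap buffer
        std::fputs(big.data(), stderr);
    } else if (len >= 0) {
        std::fputs(buf, stderr);
    }
    va_end(args2);
    va_end(args);
}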
 [[noreturn]]
 void ggml_cuda_error(const char * stmt, const char * func, const char * file, int line, const char * msg) {
     int id = -1; // in case cudaGetDevice fails
     cudaGetDevice(&id);

-    fprintf(stderr, "CUDA error: %s\n", msg);
-    fprintf(stderr, "  current device: %d, in function %s at %s:%d\n", id, func, file, line);
-    fprintf(stderr, "  %s\n", stmt);
+    GGML_CUDA_LOG_ERROR("CUDA error: %s\n", msg);
+    GGML_CUDA_LOG_ERROR("  current device: %d, in function %s at %s:%d\n", id, func, file, line);
+    GGML_CUDA_LOG_ERROR("  %s\n", stmt);
     // abort with GGML_ASSERT to get a stack trace
     GGML_ASSERT(!"CUDA error");
 }
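With the new hook, an embedding application can redirect all GGML_CUDA_LOG_* output instead of letting it hit stderr; a minimal sketch, assuming the usual ggml_log_callback signature from ggml.h and an application-chosen sink:

#include <cstdio>
#include "ggml.h"
#include "ggml-cuda.h"

// Hypothetical sink: prefix each CUDA backend message with its level and send it to a file.
static void my_cuda_log_sink(enum ggml_log_level level, const char * msg, void * user_data) {
    std::fprintf((FILE *) user_data, "[cuda %d] %s", (int) level, msg); // messages carry their own '\n'
}

void install_cuda_logging(FILE * sink) {
    ggml_backend_cuda_log_set_callback(my_cuda_log_sink, sink);
}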
@@ -93,7 +133,7 @@ static ggml_cuda_device_info ggml_cuda_init() {
     cudaError_t err = cudaGetDeviceCount(&info.device_count);
     if (err != cudaSuccess) {
-        fprintf(stderr, "%s: failed to initialize " GGML_CUDA_NAME ": %s\n", __func__, cudaGetErrorString(err));
+        GGML_CUDA_LOG_ERROR("%s: failed to initialize " GGML_CUDA_NAME ": %s\n", __func__, cudaGetErrorString(err));
         return info;
     }

@@ -101,16 +141,16 @@ static ggml_cuda_device_info ggml_cuda_init() {
     int64_t total_vram = 0;
-    // #if defined(GGML_CUDA_FORCE_MMQ)
-    //     fprintf(stderr, "%s: GGML_CUDA_FORCE_MMQ:   yes\n", __func__);
-    // #else
-    //     fprintf(stderr, "%s: GGML_CUDA_FORCE_MMQ:   no\n", __func__);
-    // #endif
-    // #if defined(CUDA_USE_TENSOR_CORES)
-    //     fprintf(stderr, "%s: CUDA_USE_TENSOR_CORES: yes\n", __func__);
-    // #else
-    //     fprintf(stderr, "%s: CUDA_USE_TENSOR_CORES: no\n", __func__);
-    // #endif
-    fprintf(stderr, "%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, info.device_count);
+    // #if defined(GGML_CUDA_FORCE_MMQ)
+    //     GGML_CUDA_LOG_INFO("%s: GGML_CUDA_FORCE_MMQ:   yes\n", __func__);
+    // #else
+    //     GGML_CUDA_LOG_INFO("%s: GGML_CUDA_FORCE_MMQ:   no\n", __func__);
+    // #endif
+    // #if defined(CUDA_USE_TENSOR_CORES)
+    //     GGML_CUDA_LOG_INFO("%s: CUDA_USE_TENSOR_CORES: yes\n", __func__);
+    // #else
+    //     GGML_CUDA_LOG_INFO("%s: CUDA_USE_TENSOR_CORES: no\n", __func__);
+    // #endif
+    GGML_CUDA_LOG_INFO("%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, info.device_count);
     for (int id = 0; id < info.device_count; ++id) {
         int device_vmm = 0;

@@ -131,7 +171,7 @@ static ggml_cuda_device_info ggml_cuda_init() {
         cudaDeviceProp prop;
         CUDA_CHECK(cudaGetDeviceProperties(&prop, id));
-        fprintf(stderr, "  Device %d: %s, compute capability %d.%d, VMM: %s\n", id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");
+        GGML_CUDA_LOG_INFO("  Device %d: %s, compute capability %d.%d, VMM: %s\n", id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");

         info.default_tensor_split[id] = total_vram;
         total_vram += prop.totalGlobalMem;

@@ -250,7 +290,7 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool {
                 return;
             }
         }
-        fprintf(stderr, "WARNING: cuda buffer pool full, increase MAX_CUDA_BUFFERS\n");
+        GGML_CUDA_LOG_WARN("Cuda buffer pool full, increase MAX_CUDA_BUFFERS\n");
         ggml_cuda_set_device(device);
         CUDA_CHECK(cudaFree(ptr));
         pool_size -= size;

@@ -499,7 +539,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffe
     void * dev_ptr;
     cudaError_t err = cudaMalloc(&dev_ptr, size);
     if (err != cudaSuccess) {
-        fprintf(stderr, "%s: allocating %.2f MiB on device %d: cudaMalloc failed: %s\n", __func__, size/1024.0/1024.0, buft_ctx->device, cudaGetErrorString(err));
+        GGML_CUDA_LOG_ERROR("%s: allocating %.2f MiB on device %d: cudaMalloc failed: %s\n", __func__, size / 1024.0 / 1024.0, buft_ctx->device, cudaGetErrorString(err));
         return nullptr;
     }

@@ -1002,7 +1042,7 @@ static void * ggml_cuda_host_malloc(size_t size) {
     if (err != cudaSuccess) {
         // clear the error
         cudaGetLastError();
-        fprintf(stderr, "%s: warning: failed to allocate %.2f MiB of pinned memory: %s\n", __func__,
-                size / 1024.0 / 1024.0, cudaGetErrorString(err));
+        GGML_CUDA_LOG_WARN("%s: failed to allocate %.2f MiB of pinned memory: %s\n", __func__,
+                           size / 1024.0 / 1024.0, cudaGetErrorString(err));
         return nullptr;
     }

@@ -2252,7 +2292,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
             break;
         case GGML_OP_MUL_MAT:
             if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) {
-                fprintf(stderr, "%s: cannot compute %s: src0->ne[3] = %" PRId64 ", src1->ne[3] = %" PRId64 " - fallback to CPU\n", __func__, dst->name, dst->src[0]->ne[3], dst->src[1]->ne[3]);
+                GGML_CUDA_LOG_ERROR("%s: cannot compute %s: src0->ne[3] = %" PRId64 ", src1->ne[3] = %" PRId64 " - fallback to CPU\n", __func__, dst->name, dst->src[0]->ne[3], dst->src[1]->ne[3]);
                 return false;
             } else {
                 ggml_cuda_mul_mat(ctx, dst->src[0], dst->src[1], dst);

@@ -2306,7 +2346,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
     cudaError_t err = cudaGetLastError();
     if (err != cudaSuccess) {
-        fprintf(stderr, "%s: %s failed\n", __func__, ggml_op_desc(dst));
+        GGML_CUDA_LOG_ERROR("%s: %s failed\n", __func__, ggml_op_desc(dst));
         CUDA_CHECK(err);
     }

@@ -2482,7 +2522,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
         if (ggml_cuda_info().devices[cuda_ctx->device].cc < CC_AMPERE) {
             cuda_ctx->cuda_graph->disable_due_to_gpu_arch = true;
 #ifndef NDEBUG
-            fprintf(stderr, "%s: disabling CUDA graphs due to GPU architecture\n", __func__);
+            GGML_CUDA_LOG_WARN("%s: disabling CUDA graphs due to GPU architecture\n", __func__);
 #endif
         }
     }

@@ -2529,14 +2569,14 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
         if (node->src[0] && ggml_backend_buffer_is_cuda_split(node->src[0]->buffer)) {
             use_cuda_graph = false; // Split buffers are not supported by CUDA graph capture
 #ifndef NDEBUG
-            fprintf(stderr, "%s: disabling CUDA graphs due to split buffer\n", __func__);
+            GGML_CUDA_LOG_WARN("%s: disabling CUDA graphs due to split buffer\n", __func__);
 #endif
         }

         if (node->op == GGML_OP_MUL_MAT_ID) {
             use_cuda_graph = false; // This node type is not supported by CUDA graph capture
 #ifndef NDEBUG
-            fprintf(stderr, "%s: disabling CUDA graphs due to mul_mat_id\n", __func__);
+            GGML_CUDA_LOG_WARN("%s: disabling CUDA graphs due to mul_mat_id\n", __func__);
 #endif
         }

@@ -2545,7 +2585,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
             // Changes in batch size or context size can cause changes to the grid size of some kernels.
             use_cuda_graph = false;
 #ifndef NDEBUG
-            fprintf(stderr, "%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
+            GGML_CUDA_LOG_WARN("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
 #endif
         }

@@ -2573,7 +2613,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
         if (cuda_ctx->cuda_graph->number_consecutive_updates >= 4) {
             cuda_ctx->cuda_graph->disable_due_to_too_many_updates = true;
 #ifndef NDEBUG
-            fprintf(stderr, "%s: disabling CUDA graphs due to too many consecutive updates\n", __func__);
+            GGML_CUDA_LOG_WARN("%s: disabling CUDA graphs due to too many consecutive updates\n", __func__);
 #endif
         }
     }

@@ -2611,7 +2651,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
             bool ok = ggml_cuda_compute_forward(*cuda_ctx, node);
             if (!ok) {
-                fprintf(stderr, "%s: error: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
+                GGML_CUDA_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
             }
             GGML_ASSERT(ok);
         }

@@ -2630,7 +2670,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
                 use_cuda_graph = false;
                 cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture = true;
 #ifndef NDEBUG
-                fprintf(stderr, "%s: disabling CUDA graphs due to failed graph capture\n", __func__);
+                GGML_CUDA_LOG_WARN("%s: disabling CUDA graphs due to failed graph capture\n", __func__);
 #endif
             } else {
                 graph_evaluated_or_captured = true; // CUDA graph has been captured

@@ -2697,7 +2737,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
         cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info);
         if (stat == cudaErrorGraphExecUpdateFailure) {
 #ifndef NDEBUG
-            fprintf(stderr, "%s: CUDA graph update failed\n", __func__);
+            GGML_CUDA_LOG_ERROR("%s: CUDA graph update failed\n", __func__);
 #endif
             // The pre-existing graph exec cannot be updated due to violated constraints
             // so instead clear error and re-instantiate

@@ -2954,13 +2994,13 @@ static ggml_guid_t ggml_backend_cuda_guid() {
 GGML_CALL ggml_backend_t ggml_backend_cuda_init(int device) {
     if (device < 0 || device >= ggml_backend_cuda_get_device_count()) {
-        fprintf(stderr, "%s: error: invalid device %d\n", __func__, device);
+        GGML_CUDA_LOG_ERROR("%s: invalid device %d\n", __func__, device);
         return nullptr;
     }

     ggml_backend_cuda_context * ctx = new ggml_backend_cuda_context(device);
     if (ctx == nullptr) {
-        fprintf(stderr, "%s: error: failed to allocate context\n", __func__);
+        GGML_CUDA_LOG_ERROR("%s: failed to allocate context\n", __func__);
         return nullptr;
     }

@@ -3004,7 +3044,7 @@ GGML_CALL bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size
         // clear the error
         cudaGetLastError();
-        fprintf(stderr, "%s: warning: failed to register %.2f MiB of pinned memory: %s\n", __func__,
-                size / 1024.0 / 1024.0, cudaGetErrorString(err));
+        GGML_CUDA_LOG_WARN("%s: failed to register %.2f MiB of pinned memory: %s\n", __func__,
+                           size / 1024.0 / 1024.0, cudaGetErrorString(err));
         return false;
     }


@@ -39,6 +39,7 @@ GGML_API GGML_CALL void ggml_backend_cuda_get_device_memory(int device, size_t *
 GGML_API GGML_CALL bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size);
 GGML_API GGML_CALL void ggml_backend_cuda_unregister_host_buffer(void * buffer);

+GGML_API void ggml_backend_cuda_log_set_callback(ggml_log_callback log_callback, void * user_data);

 #ifdef __cplusplus
 }
 #endif


@@ -315,6 +315,20 @@ static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
 #endif
     return c;
 }

+#if defined(__HIP_PLATFORM_AMD__) && HIP_VERSION < 50600000
+// __shfl_xor() for half2 was added in ROCm 5.6
+static __device__ __forceinline__ half2 __shfl_xor(half2 var, int laneMask, int width) {
+    typedef union half2_b32 {
+        half2 val;
+        int   b32;
+    } half2_b32_t;
+    half2_b32_t tmp;
+    tmp.val = var;
+    tmp.b32 = __shfl_xor(tmp.b32, laneMask, width);
+    return tmp.val;
+}
+#endif // defined(__HIP_PLATFORM_AMD__) && HIP_VERSION < 50600000
+
 #endif // defined(GGML_USE_HIPBLAS)

 #define FP16_AVAILABLE (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
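The union above exists because, on older ROCm, warp shuffles only move 32-bit lanes; the bit-for-bit round trip itself is a general idiom. A host-side C++20 stand-in with mock types (not the device code above):

#include <bit>
#include <cstdint>

struct half2_mock { uint16_t x, y; }; // stand-in for CUDA's half2
static_assert(sizeof(half2_mock) == sizeof(int32_t), "must pack into one 32-bit lane");

int32_t    to_lane(half2_mock v)    { return std::bit_cast<int32_t>(v); }     // reinterpret bits, no conversion
half2_mock from_lane(int32_t bits)  { return std::bit_cast<half2_mock>(bits); }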
@@ -463,6 +477,17 @@ static const __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -
 typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v);

+static __device__ __forceinline__ float get_alibi_slope(
+    const float max_bias, const uint32_t h, const uint32_t n_head_log2, const float m0, const float m1
+) {
+    if (max_bias <= 0.0f) {
+        return 1.0f;
+    }
+    const float base = h < n_head_log2 ? m0 : m1;
+    const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+    return powf(base, exph);
+}
+
 //////////////////////
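In LaTeX terms, with $b$ = max_bias and $n$ = n_head_log2 (the launcher supplies $m_0 = 2^{-b/n}$ and $m_1 = 2^{-b/(2n)}$), the ALiBi slope for head $h$ computed above is

$$
\mathrm{slope}(h) =
\begin{cases}
1 & b \le 0,\\
m_0^{\,h+1} & h < n,\\
m_1^{\,2(h-n)+1} & h \ge n.
\end{cases}
$$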


@@ -1,7 +1,44 @@
+#include "common.cuh"
+
+#include <cstdint>
+
 #define FATTN_KQ_STRIDE       256
 #define HALF_MAX_HALF         __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
 #define SOFTMAX_FTZ_THRESHOLD -20.0f                   // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs.
+typedef void (* fattn_kernel_t)(
+    const char * __restrict__ Q,
+    const char * __restrict__ K,
+    const char * __restrict__ V,
+    const char * __restrict__ mask,
+    float      * __restrict__ dst,
+    float2     * __restrict__ dst_meta,
+    const float scale,
+    const float max_bias,
+    const float m0,
+    const float m1,
+    const uint32_t n_head_log2,
+    const int ne00,
+    const int ne01,
+    const int ne02,
+    const int ne03,
+    const int ne10,
+    const int ne11,
+    const int ne12,
+    const int ne13,
+    const int ne31,
+    const int nb31,
+    const int nb01,
+    const int nb02,
+    const int nb03,
+    const int nb11,
+    const int nb12,
+    const int nb13,
+    const int ne0,
+    const int ne1,
+    const int ne2,
+    const int ne3);
+
 template<int D, int parallel_blocks> // D == head size
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(D, 1)

@@ -45,3 +82,81 @@ static __global__ void flash_attn_combine_results(
     dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator;
 }
+template <int D, int parallel_blocks>
+void launch_fattn(ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, int nwarps, int cols_per_block) {
+    const ggml_tensor * Q    = dst->src[0];
+    const ggml_tensor * K    = dst->src[1];
+    const ggml_tensor * V    = dst->src[2];
+    const ggml_tensor * mask = dst->src[3];
+    ggml_tensor * KQV = dst;
+
+    GGML_ASSERT(Q->type   == GGML_TYPE_F32);
+    GGML_ASSERT(K->type   == GGML_TYPE_F16);
+    GGML_ASSERT(V->type   == GGML_TYPE_F16);
+    GGML_ASSERT(KQV->type == GGML_TYPE_F32);
+
+    GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
+    GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
+        "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
+
+    GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
+
+    ggml_cuda_pool & pool = ctx.pool();
+    cudaStream_t main_stream = ctx.stream();
+
+    ggml_cuda_pool_alloc<float>  dst_tmp(pool);
+    ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
+
+    if (parallel_blocks > 1) {
+        dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
+        dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
+    }
+
+    const dim3 block_dim(WARP_SIZE, nwarps, 1);
+    const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
+    const int  shmem = 0;
+
+    float scale    = 1.0f;
+    float max_bias = 0.0f;
+
+    memcpy(&scale,    (float *) KQV->op_params + 0, sizeof(float));
+    memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
+
+    const uint32_t n_head      = Q->ne[2];
+    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+
+    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
+    fattn_kernel<<<blocks_num, block_dim, shmem, main_stream>>>(
+        (const char *) Q->data,
+        (const char *) K->data,
+        (const char *) V->data,
+        mask ? ((const char *) mask->data) : nullptr,
+        (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
+        scale, max_bias, m0, m1, n_head_log2,
+        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
+        K->ne[0], K->ne[1], K->ne[2], K->ne[3],
+        mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
+        Q->nb[1], Q->nb[2], Q->nb[3],
+        K->nb[1], K->nb[2], K->nb[3],
+        KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
+    );
+    CUDA_CHECK(cudaGetLastError());
+
+    if ((parallel_blocks) == 1) {
+        return;
+    }
+
+    const dim3 block_dim_combine(D, 1, 1);
+    const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
+    const int  shmem_combine = 0;
+
+    flash_attn_combine_results<D, parallel_blocks>
+        <<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
+        (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
+    CUDA_CHECK(cudaGetLastError());
+}
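A quick sanity check of the launch grid above, under assumed shapes: for Q->ne[1] = 100 query columns with cols_per_block = 32 and parallel_blocks = 4, each KV-parallel replica needs ceil(100/32) = 4 column blocks, so blocks_num.x = 16, while y and z walk the head and batch dimensions. As host C++:

#include <cstdio>

int main() {
    const int ne1 = 100, cols_per_block = 32, parallel_blocks = 4;       // assumed example shapes
    const int col_blocks = (ne1 + cols_per_block - 1) / cols_per_block;  // ceiling division, as in launch_fattn
    std::printf("blocks_num.x = %d\n", parallel_blocks * col_blocks);    // prints 16
    return 0;
}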

ggml-cuda/fattn-tile-f16.cu (new file, +312)

@@ -0,0 +1,312 @@
#include "common.cuh"
#include "fattn-common.cuh"
#include "fattn-tile-f16.cuh"
#define FATTN_KQ_STRIDE_TILE_F16 64
template<int D, int ncols, int nwarps, int parallel_blocks> // D == head size
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__(nwarps*WARP_SIZE, 1)
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
static __global__ void flash_attn_tile_ext_f16(
const char * __restrict__ Q,
const char * __restrict__ K,
const char * __restrict__ V,
const char * __restrict__ mask,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
const float scale,
const float max_bias,
const float m0,
const float m1,
const uint32_t n_head_log2,
const int ne00,
const int ne01,
const int ne02,
const int ne03,
const int ne10,
const int ne11,
const int ne12,
const int ne13,
const int ne31,
const int nb31,
const int nb01,
const int nb02,
const int nb03,
const int nb11,
const int nb12,
const int nb13,
const int ne0,
const int ne1,
const int ne2,
const int ne3) {
#if FP16_AVAILABLE
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y + nb01*ic0);
const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.y / gqa_ratio));
const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
const half * maskh = (const half *) mask + ne11*ic0;
const int stride_KV2 = nb11 / sizeof(half2);
const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
const half slopeh = __float2half(slopef);
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
__shared__ half KQ[ncols*FATTN_KQ_STRIDE_TILE_F16];
half2 * KQ2 = (half2 *) KQ;
__shared__ half2 KV_tmp[FATTN_KQ_STRIDE_TILE_F16][D/2 + 1]; // Pad D to avoid memory bank conflicts.
half kqmax[ncols/nwarps];
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
kqmax[j0/nwarps] = -HALF_MAX_HALF;
}
half2 kqsum[ncols/nwarps] = {{0.0f, 0.0f}};
half2 VKQ[ncols/nwarps][(D/2)/WARP_SIZE] = {{{0.0f, 0.0f}}};
// Convert Q to half2 and store in registers:
__shared__ half2 Q_h2[ncols][D/2];
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j = j0 + threadIdx.y;
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
const float2 tmp = Q_f2[j*(nb01/sizeof(float2)) + i];
Q_h2[j][i] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y);
}
}
__syncthreads();
const int k_start = parallel_blocks == 1 ? 0 : ip*FATTN_KQ_STRIDE_TILE_F16;
for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE_TILE_F16) {
// Calculate KQ tile and keep track of new maximum KQ values:
half kqmax_new[ncols/nwarps];
#pragma unroll
for (int j = 0; j < ncols/nwarps; ++j) {
kqmax_new[j] = kqmax[j];
}
#pragma unroll
for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F16; i_KQ_0 += nwarps) {
const int i_KQ = i_KQ_0 + threadIdx.y;
#pragma unroll
for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
const int k_KQ = k_KQ_0 + threadIdx.x;
KV_tmp[i_KQ][k_KQ] = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ];
}
}
__syncthreads();
half2 sum2[FATTN_KQ_STRIDE_TILE_F16/WARP_SIZE][ncols/nwarps] = {{{0.0f, 0.0f}}};
#pragma unroll
for (int k_KQ = 0; k_KQ < D/2; ++k_KQ) {
half2 K_k[FATTN_KQ_STRIDE_TILE_F16/WARP_SIZE];
half2 Q_k[ncols/nwarps];
#pragma unroll
for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F16; i_KQ_0 += WARP_SIZE) {
const int i_KQ = i_KQ_0 + threadIdx.x;
K_k[i_KQ_0/WARP_SIZE] = KV_tmp[i_KQ][k_KQ];
}
#pragma unroll
for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
const int j_KQ = j_KQ_0 + threadIdx.y;
Q_k[j_KQ_0/nwarps] = Q_h2[j_KQ][k_KQ];
}
#pragma unroll
for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F16; i_KQ_0 += WARP_SIZE) {
#pragma unroll
for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] += K_k[i_KQ_0/WARP_SIZE]*Q_k[j_KQ_0/nwarps];
}
}
}
#pragma unroll
for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F16; i_KQ_0 += WARP_SIZE) {
const int i_KQ = i_KQ_0 + threadIdx.x;
#pragma unroll
for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
const int j_KQ = j_KQ_0 + threadIdx.y;
half sum = __low2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]) + __high2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
sum += mask ? slopeh*maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
kqmax_new[j_KQ_0/nwarps] = ggml_cuda_hmax(kqmax_new[j_KQ_0/nwarps], sum);
KQ[j_KQ*FATTN_KQ_STRIDE_TILE_F16 + i_KQ] = sum;
}
}
__syncthreads();
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j = j0 + threadIdx.y;
kqmax_new[j0/nwarps] = warp_reduce_max(kqmax_new[j0/nwarps]);
const half2 KQ_max_scale = __half2half2(hexp(kqmax[j0/nwarps] - kqmax_new[j0/nwarps]));
kqmax[j0/nwarps] = kqmax_new[j0/nwarps];
#pragma unroll
for (int i0 = 0; i0 < FATTN_KQ_STRIDE_TILE_F16/2; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
const half2 diff = KQ2[j*(FATTN_KQ_STRIDE_TILE_F16/2) + i] - __half2half2(kqmax[j0/nwarps]);
const half2 val = h2exp(diff);
kqsum[j0/nwarps] = kqsum[j0/nwarps]*KQ_max_scale + val;
KQ2[j*(FATTN_KQ_STRIDE_TILE_F16/2) + i] = val;
}
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
VKQ[j0/nwarps][i0/WARP_SIZE] *= KQ_max_scale;
}
}
__syncthreads();
#pragma unroll
for (int k0 = 0; k0 < FATTN_KQ_STRIDE_TILE_F16; k0 += nwarps) {
const int k = k0 + threadIdx.y;
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
KV_tmp[k][i] = V_h2[(k_VKQ_0 + k)*stride_KV2 + i];
}
}
__syncthreads();
#pragma unroll
for (int k0 = 0; k0 < FATTN_KQ_STRIDE_TILE_F16; k0 += 2) {
half2 V_k[(D/2)/WARP_SIZE][2];
half2 KQ_k[ncols/nwarps];
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
V_k[i0/WARP_SIZE][0] = KV_tmp[k0 + 0][i];
V_k[i0/WARP_SIZE][1] = KV_tmp[k0 + 1][i];
}
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j = j0 + threadIdx.y;
KQ_k[j0/nwarps] = KQ2[j*(FATTN_KQ_STRIDE_TILE_F16/2) + k0/2];
}
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
VKQ[j0/nwarps][i0/WARP_SIZE] += V_k[i0/WARP_SIZE][0]* __low2half2(KQ_k[j0/nwarps]);
VKQ[j0/nwarps][i0/WARP_SIZE] += V_k[i0/WARP_SIZE][1]*__high2half2(KQ_k[j0/nwarps]);
}
}
}
__syncthreads();
}
#pragma unroll
for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
const int j_VKQ = j_VKQ_0 + threadIdx.y;
half kqsum_j = __low2half(kqsum[j_VKQ_0/nwarps]) + __high2half(kqsum[j_VKQ_0/nwarps]);
kqsum_j = warp_reduce_sum(kqsum_j);
#pragma unroll
for (int i00 = 0; i00 < D; i00 += 2*WARP_SIZE) {
const int i0 = i00 + 2*threadIdx.x;
half2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)];
if (parallel_blocks == 1) {
dst_val /= __half2half2(kqsum_j);
}
const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 0] = __low2float(dst_val);
dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 1] = __high2float(dst_val);
}
if (parallel_blocks != 1 && threadIdx.x == 0) {
dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
}
}
#else
NO_DEVICE_CODE;
#endif // FP16_AVAILABLE
}
template <int cols_per_block, int parallel_blocks>
void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * Q = dst->src[0];
switch (Q->ne[0]) {
case 64: {
constexpr int D = 64;
constexpr int nwarps = 8;
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>;
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
} break;
case 128: {
constexpr int D = 128;
constexpr int nwarps = 8;
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>;
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
} break;
default: {
GGML_ASSERT(false && "FlashAttention without tensor cores only supports head sizes 64 and 128.");
} break;
}
}
void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * KQV = dst;
const ggml_tensor * Q = dst->src[0];
const int32_t precision = KQV->op_params[2];
GGML_ASSERT(precision == GGML_PREC_DEFAULT);
if (Q->ne[1] <= 16) {
constexpr int cols_per_block = 16;
constexpr int parallel_blocks = 4;
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
return;
}
if (Q->ne[1] <= 32) {
constexpr int cols_per_block = 32;
constexpr int parallel_blocks = 4;
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
return;
}
constexpr int cols_per_block = 32;
constexpr int parallel_blocks = 1;
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
}
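The kqmax/kqsum/VKQ updates inside the KV loop above are the standard streaming-softmax identity: for a running maximum $M$, normalizer $S$, and accumulator $O$, each new tile of scores $\{s_i\}$ with values $V_i$ is folded in as

$$
M' = \max\!\big(M, \max_i s_i\big), \qquad
S' = S\,e^{M-M'} + \sum_i e^{s_i-M'}, \qquad
O' = O\,e^{M-M'} + \sum_i e^{s_i-M'}\,V_i,
$$

so that $O'/S'$ always equals the exact softmax-weighted sum over everything processed so far.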


@@ -0,0 +1,3 @@
#include "common.cuh"
void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

ggml-cuda/fattn-tile-f32.cu (new file, +309)

@@ -0,0 +1,309 @@
#include "common.cuh"
#include "fattn-common.cuh"
#include "fattn-tile-f32.cuh"
#define FATTN_KQ_STRIDE_TILE_F32 32
template<int D, int ncols, int nwarps, int parallel_blocks> // D == head size
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__(nwarps*WARP_SIZE, 1)
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
static __global__ void flash_attn_tile_ext_f32(
const char * __restrict__ Q,
const char * __restrict__ K,
const char * __restrict__ V,
const char * __restrict__ mask,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
const float scale,
const float max_bias,
const float m0,
const float m1,
const uint32_t n_head_log2,
const int ne00,
const int ne01,
const int ne02,
const int ne03,
const int ne10,
const int ne11,
const int ne12,
const int ne13,
const int ne31,
const int nb31,
const int nb01,
const int nb02,
const int nb03,
const int nb11,
const int nb12,
const int nb13,
const int ne0,
const int ne1,
const int ne2,
const int ne3) {
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y + nb01*ic0);
const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.y / gqa_ratio));
const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
const half * maskh = (const half *) mask + ne11*ic0;
const int stride_KV2 = nb11 / sizeof(half2);
const float slope = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
__shared__ float KQ[ncols*FATTN_KQ_STRIDE_TILE_F32];
__shared__ float KV_tmp[FATTN_KQ_STRIDE_TILE_F32][D + 1]; // Pad D to avoid memory bank conflicts.
float2 * KV_tmp2 = (float2 *) KV_tmp;
float kqmax[ncols/nwarps];
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
kqmax[j0/nwarps] = -FLT_MAX/2.0f;
}
float kqsum[ncols/nwarps] = {0.0f};
float2 VKQ[ncols/nwarps][(D/2)/WARP_SIZE] = {{{0.0f, 0.0f}}};
// Convert Q to half2 and store in registers:
__shared__ float Q_f[ncols][D];
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j = j0 + threadIdx.y;
#pragma unroll
for (int i0 = 0; i0 < D; i0 += 2*WARP_SIZE) {
float2 tmp = Q_f2[j*(nb01/sizeof(float2)) + i0/2 + threadIdx.x];
Q_f[j][i0 + 0*WARP_SIZE + threadIdx.x] = tmp.x * scale;
Q_f[j][i0 + 1*WARP_SIZE + threadIdx.x] = tmp.y * scale;
}
}
__syncthreads();
const int k_start = parallel_blocks == 1 ? 0 : ip*FATTN_KQ_STRIDE_TILE_F32;
for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE_TILE_F32) {
// Calculate KQ tile and keep track of new maximum KQ values:
float kqmax_new[ncols/nwarps];
#pragma unroll
for (int j = 0; j < ncols/nwarps; ++j) {
kqmax_new[j] = kqmax[j];
}
#pragma unroll
for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F32; i_KQ_0 += nwarps) {
const int i_KQ = i_KQ_0 + threadIdx.y;
#pragma unroll
for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 2*WARP_SIZE) {
const half2 tmp = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + threadIdx.x];
KV_tmp[i_KQ][k_KQ_0 + 0*WARP_SIZE + threadIdx.x] = __low2float(tmp);
KV_tmp[i_KQ][k_KQ_0 + 1*WARP_SIZE + threadIdx.x] = __high2float(tmp);
}
}
__syncthreads();
float sum[FATTN_KQ_STRIDE_TILE_F32/WARP_SIZE][ncols/nwarps] = {{0.0f}};
#pragma unroll
for (int k_KQ = 0; k_KQ < D; ++k_KQ) {
float K_k[FATTN_KQ_STRIDE_TILE_F32/WARP_SIZE];
float Q_k[ncols/nwarps];
#pragma unroll
for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F32; i_KQ_0 += WARP_SIZE) {
const int i_KQ = i_KQ_0 + threadIdx.x;
K_k[i_KQ_0/WARP_SIZE] = KV_tmp[i_KQ][k_KQ];
}
#pragma unroll
for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
const int j_KQ = j_KQ_0 + threadIdx.y;
Q_k[j_KQ_0/nwarps] = Q_f[j_KQ][k_KQ];
}
#pragma unroll
for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F32; i_KQ_0 += WARP_SIZE) {
#pragma unroll
for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] += K_k[i_KQ_0/WARP_SIZE] * Q_k[j_KQ_0/nwarps];
}
}
}
#pragma unroll
for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F32; i_KQ_0 += WARP_SIZE) {
const int i_KQ = i_KQ_0 + threadIdx.x;
#pragma unroll
for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
const int j_KQ = j_KQ_0 + threadIdx.y;
sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] += mask ? slope*__half2float(maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
kqmax_new[j_KQ_0/nwarps] = fmaxf(kqmax_new[j_KQ_0/nwarps], sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
KQ[j_KQ*FATTN_KQ_STRIDE_TILE_F32 + i_KQ] = sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps];
}
}
__syncthreads();
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j = j0 + threadIdx.y;
kqmax_new[j0/nwarps] = warp_reduce_max(kqmax_new[j0/nwarps]);
const float KQ_max_scale = expf(kqmax[j0/nwarps] - kqmax_new[j0/nwarps]);
kqmax[j0/nwarps] = kqmax_new[j0/nwarps];
float kqsum_add = 0.0f;
#pragma unroll
for (int i0 = 0; i0 < FATTN_KQ_STRIDE_TILE_F32; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
const float diff = KQ[j*FATTN_KQ_STRIDE_TILE_F32 + i] - kqmax[j0/nwarps];
const float val = expf(diff);
kqsum_add += val;
KQ[j*FATTN_KQ_STRIDE_TILE_F32 + i] = val;
}
kqsum[j0/nwarps] = kqsum[j0/nwarps]*KQ_max_scale + kqsum_add;
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
VKQ[j0/nwarps][i0/WARP_SIZE].x *= KQ_max_scale;
VKQ[j0/nwarps][i0/WARP_SIZE].y *= KQ_max_scale;
}
}
__syncthreads();
#pragma unroll
for (int k0 = 0; k0 < FATTN_KQ_STRIDE_TILE_F32; k0 += nwarps) {
const int k = k0 + threadIdx.y;
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
KV_tmp2[k*(D/2) + i].x = __low2float(V_h2[(k_VKQ_0 + k)*stride_KV2 + i]);
KV_tmp2[k*(D/2) + i].y = __high2float(V_h2[(k_VKQ_0 + k)*stride_KV2 + i]);
}
}
__syncthreads();
#pragma unroll
for (int k = 0; k < FATTN_KQ_STRIDE_TILE_F32; ++k) {
float2 V_k[(D/2)/WARP_SIZE];
float KQ_k[ncols/nwarps];
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
V_k[i0/WARP_SIZE] = KV_tmp2[k*(D/2) + i];
}
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j = j0 + threadIdx.y;
KQ_k[j0/nwarps] = KQ[j*FATTN_KQ_STRIDE_TILE_F32 + k];
}
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
VKQ[j0/nwarps][i0/WARP_SIZE].x += V_k[i0/WARP_SIZE].x*KQ_k[j0/nwarps];
VKQ[j0/nwarps][i0/WARP_SIZE].y += V_k[i0/WARP_SIZE].y*KQ_k[j0/nwarps];
}
}
}
__syncthreads();
}
#pragma unroll
for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
const int j_VKQ = j_VKQ_0 + threadIdx.y;
float kqsum_j = kqsum[j_VKQ_0/nwarps];
kqsum_j = warp_reduce_sum(kqsum_j);
#pragma unroll
for (int i00 = 0; i00 < D; i00 += 2*WARP_SIZE) {
const int i0 = i00 + 2*threadIdx.x;
float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)];
if (parallel_blocks == 1) {
dst_val.x /= kqsum_j;
dst_val.y /= kqsum_j;
}
const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 0] = dst_val.x;
dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 1] = dst_val.y;
}
if (parallel_blocks != 1 && threadIdx.x == 0) {
dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
}
}
}
template <int cols_per_block, int parallel_blocks>
void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * Q = dst->src[0];
switch (Q->ne[0]) {
case 64: {
constexpr int D = 64;
constexpr int nwarps = 8;
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks>;
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
} break;
case 128: {
constexpr int D = 128;
constexpr int nwarps = 8;
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks>;
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
} break;
default: {
GGML_ASSERT(false && "FlashAttention without tensor cores only supports head sizes 64 and 128.");
} break;
}
}
void ggml_cuda_flash_attn_ext_tile_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * KQV = dst;
const ggml_tensor * Q = dst->src[0];
const int32_t precision = KQV->op_params[2];
GGML_ASSERT(precision == GGML_PREC_DEFAULT);
if (Q->ne[1] <= 16) {
constexpr int cols_per_block = 16;
constexpr int parallel_blocks = 4;
launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
return;
}
if (Q->ne[1] <= 32) {
constexpr int cols_per_block = 32;
constexpr int parallel_blocks = 4;
launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
return;
}
constexpr int cols_per_block = 32;
constexpr int parallel_blocks = 1;
launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
}
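When parallel_blocks > 1, each block writes a partial accumulator plus its dst_meta pair $(M_j, S_j)$; flash_attn_combine_results then merges the per-block partials $(O_j, M_j, S_j)$ with the same rescaling trick:

$$
M = \max_j M_j, \qquad
O = \frac{\sum_j e^{M_j - M}\, O_j}{\sum_j e^{M_j - M}\, S_j}.
$$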


@@ -0,0 +1,3 @@
#include "common.cuh"
void ggml_cuda_flash_attn_ext_tile_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst);


@@ -53,17 +53,8 @@ static __global__ void flash_attn_vec_ext_f16(
     const int stride_KV  = nb11 / sizeof(half);
     const int stride_KV2 = nb11 / sizeof(half2);

-    half slopeh = __float2half(1.0f);
-
-    // ALiBi
-    if (max_bias > 0.0f) {
-        const int h = blockIdx.y;
-
-        const float base = h < n_head_log2 ? m0 : m1;
-        const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
-
-        slopeh = __float2half(powf(base, exph));
-    }
+    const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
+    const half  slopeh = __float2half(slopef);

     static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
     constexpr int nwarps = D / WARP_SIZE;
@@ -232,85 +223,17 @@ static __global__ void flash_attn_vec_ext_f16(
         dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
     }

-    if (parallel_blocks != 1 && tid != 0) {
-#pragma unroll
-        for (int j = 0; j < ncols; ++j) {
-            dst_meta[(ic0 + j)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[j], kqsum[j]);
-        }
+    if (parallel_blocks != 1 && tid < ncols) {
+        dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
     }
 #else
     NO_DEVICE_CODE;
 #endif // FP16_AVAILABLE
 }
-template <int D, int cols_per_block, int parallel_blocks> void launch_fattn_vec_f16(
-        const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
-        ggml_cuda_pool & pool, cudaStream_t main_stream
-) {
-    ggml_cuda_pool_alloc<float>  dst_tmp(pool);
-    ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
-
-    if (parallel_blocks > 1) {
-        dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
-        dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
-    }
-
-    constexpr int nwarps = (D + WARP_SIZE - 1) / WARP_SIZE;
-    const dim3 block_dim(WARP_SIZE, nwarps, 1);
-    const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
-    const int  shmem = 0;
-
-    float scale    = 1.0f;
-    float max_bias = 0.0f;
-
-    memcpy(&scale,    (float *) KQV->op_params + 0, sizeof(float));
-    memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
-
-    const uint32_t n_head      = Q->ne[2];
-    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
-
-    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
-    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
-
-    flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks>
-        <<<blocks_num, block_dim, shmem, main_stream>>> (
-                (const char *) Q->data,
-                (const char *) K->data,
-                (const char *) V->data,
-                mask ? ((const char *) mask->data) : nullptr,
-                parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
-                scale, max_bias, m0, m1, n_head_log2,
-                Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
-                K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-                mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
-                Q->nb[1], Q->nb[2], Q->nb[3],
-                K->nb[1], K->nb[2], K->nb[3],
-                KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
-        );
-    CUDA_CHECK(cudaGetLastError());
-
-    if (parallel_blocks == 1) {
-        return;
-    }
-
-    const dim3 block_dim_combine(D, 1, 1);
-    const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
-    const int  shmem_combine = 0;
-
-    flash_attn_combine_results<D, parallel_blocks>
-        <<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
-        (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
-    CUDA_CHECK(cudaGetLastError());
-}
-
 void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * Q    = dst->src[0];
-    const ggml_tensor * K    = dst->src[1];
-    const ggml_tensor * V    = dst->src[2];
-    const ggml_tensor * mask = dst->src[3];
     ggml_tensor * KQV = dst;
+    ggml_tensor * Q   = dst->src[0];

     const int32_t precision = KQV->op_params[2];
     GGML_ASSERT(precision == GGML_PREC_DEFAULT);
@@ -318,113 +241,86 @@ void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, ggml_tens
     constexpr int cols_per_block  = 1;
     constexpr int parallel_blocks = 4;
     switch (Q->ne[0]) {
-        case 64:
-            launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-            break;
-        case 128:
-            launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-            break;
-        case 256:
-            launch_fattn_vec_f16<256, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-            break;
+        case 64: {
+            constexpr int D = 64;
+            constexpr int nwarps = D/WARP_SIZE;
+            fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks>;
+            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+        } break;
+        case 128: {
+            constexpr int D = 128;
+            constexpr int nwarps = D/WARP_SIZE;
+            fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks>;
+            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+        } break;
+        case 256: {
+            constexpr int D = 256;
+            constexpr int nwarps = D/WARP_SIZE;
+            fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks>;
+            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+        } break;
         default:
             GGML_ASSERT(false);
             break;
     }
 }
+template <int cols_per_block, int parallel_blocks>
+void launch_fattn_vec_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * Q = dst->src[0];
+    switch (Q->ne[0]) {
+        case 64: {
+            constexpr int D = 64;
+            constexpr int nwarps = D/WARP_SIZE;
+            fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks>;
+            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+        } break;
+        case 128: {
+            constexpr int D = 128;
+            constexpr int nwarps = D/WARP_SIZE;
+            fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks>;
+            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+        } break;
+        default: {
+            GGML_ASSERT(false && "FlashAttention without tensor cores only supports head sizes 64 and 128.");
+        } break;
+    }
+}
 void ggml_cuda_flash_attn_ext_vec_f16_no_mma(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * KQV = dst;
     const ggml_tensor * Q = dst->src[0];
-    const ggml_tensor * K    = dst->src[1];
-    const ggml_tensor * V    = dst->src[2];
-    const ggml_tensor * mask = dst->src[3];
-    ggml_tensor * KQV = dst;

     const int32_t precision = KQV->op_params[2];
     GGML_ASSERT(precision == GGML_PREC_DEFAULT);
+    GGML_ASSERT(Q->ne[0] == 64 || Q->ne[0] == 128 && "FlashAttention without tensor cores only supports head sizes 64 and 128.");

     if (Q->ne[1] == 1) {
-        constexpr int cols_per_block  = 1;
-        constexpr int parallel_blocks = 4;
-        switch (Q->ne[0]) {
-            case 64:
-                launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            case 128:
-                launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            default:
-                GGML_ASSERT(false);
-                break;
-        }
+        ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
         return;
     }
     if (Q->ne[1] == 2) {
         constexpr int cols_per_block  = 2;
         constexpr int parallel_blocks = 4;
-        switch (Q->ne[0]) {
-            case 64:
-                launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            case 128:
-                launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            default:
-                GGML_ASSERT(false);
-                break;
-        }
+        launch_fattn_vec_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
         return;
     }

     if (Q->ne[1] <= 4) {
         constexpr int cols_per_block  = 4;
         constexpr int parallel_blocks = 4;
-        switch (Q->ne[0]) {
-            case 64:
-                launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            case 128:
-                launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            default:
-                GGML_ASSERT(false);
-                break;
-        }
+        launch_fattn_vec_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
         return;
     }

     if (Q->ne[1] <= 8) {
         constexpr int cols_per_block  = 8;
         constexpr int parallel_blocks = 4;
-        switch (Q->ne[0]) {
-            case 64:
-                launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            case 128:
-                launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            default:
-                GGML_ASSERT(false);
-                break;
-        }
+        launch_fattn_vec_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
         return;
     }

     constexpr int cols_per_block  = 8;
     constexpr int parallel_blocks = 1;
-    switch (Q->ne[0]) {
-        case 64:
-            launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-            break;
-        case 128:
-            launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-            break;
-        default:
-            GGML_ASSERT(false);
-            break;
-    }
+    launch_fattn_vec_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
 }


@@ -52,17 +52,7 @@ static __global__ void flash_attn_vec_ext_f32(
     const int stride_KV  = nb11 / sizeof(half);
     const int stride_KV2 = nb11 / sizeof(half2);

-    float slope = 1.0f;
-
-    // ALiBi
-    if (max_bias > 0.0f) {
-        const int h = blockIdx.y;
-
-        const float base = h < n_head_log2 ? m0 : m1;
-        const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
-
-        slope = powf(base, exph);
-    }
+    const float slope = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
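
The inlined ALiBi computation above is replaced by a shared get_alibi_slope() helper. Reconstructed from the removed block (the actual definition lives in common.cuh), the helper presumably looks like this sketch:

    // sketch, inferred from the inlined code this diff removes
    static __device__ __forceinline__ float get_alibi_slope(
            const float max_bias, const uint32_t h, const uint32_t n_head_log2, const float m0, const float m1) {
        if (max_bias <= 0.0f) {
            return 1.0f; // ALiBi disabled
        }
        // heads below n_head_log2 use base m0 with exponent h+1, the rest use m1 with odd exponents
        const float base = h < n_head_log2 ? m0 : m1;
        const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
        return powf(base, exph);
    }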
     static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
     constexpr int nwarps = D / WARP_SIZE;

@@ -221,164 +211,65 @@ static __global__ void flash_attn_vec_ext_f32(
         dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
     }

-    if (parallel_blocks != 1 && tid != 0) {
-#pragma unroll
-        for (int j = 0; j < ncols; ++j) {
-            dst_meta[(ic0 + j)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[j], kqsum[j]);
-        }
+    if (parallel_blocks != 1 && tid < ncols) {
+        dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
     }
 }

+template <int cols_per_block, int parallel_blocks>
+void launch_fattn_vec_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * Q = dst->src[0];
+    switch (Q->ne[0]) {
+        case 64: {
+            constexpr int D      = 64;
+            constexpr int nwarps = D/WARP_SIZE;
+            fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<D, cols_per_block, parallel_blocks>;
+            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+        } break;
+        case 128: {
+            constexpr int D      = 128;
+            constexpr int nwarps = D/WARP_SIZE;
+            fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<D, cols_per_block, parallel_blocks>;
+            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+        } break;
+        default: {
+            GGML_ASSERT(false && "FlashAttention without tensor cores only supports head sizes 64 and 128.");
+        } break;
+    }
+}

-template <int D, int cols_per_block, int parallel_blocks> void launch_fattn_vec_f32(
-        const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
-        ggml_cuda_pool & pool, cudaStream_t main_stream
-) {
-    ggml_cuda_pool_alloc<float>  dst_tmp(pool);
-    ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
-
-    if (parallel_blocks > 1) {
-        dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
-        dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
-    }
-
-    constexpr int nwarps = (D + WARP_SIZE - 1) / WARP_SIZE;
-    const dim3 block_dim(WARP_SIZE, nwarps, 1);
-    const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
-    const int  shmem = 0;
-
-    float scale    = 1.0f;
-    float max_bias = 0.0f;
-
-    memcpy(&scale,    (float *) KQV->op_params + 0, sizeof(float));
-    memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
-
-    const uint32_t n_head      = Q->ne[2];
-    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
-
-    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
-    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
-
-    flash_attn_vec_ext_f32<D, cols_per_block, parallel_blocks>
-        <<<blocks_num, block_dim, shmem, main_stream>>> (
-                (const char *) Q->data,
-                (const char *) K->data,
-                (const char *) V->data,
-                mask ? ((const char *) mask->data) : nullptr,
-                parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
-                scale, max_bias, m0, m1, n_head_log2,
-                Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
-                K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-                mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
-                Q->nb[1], Q->nb[2], Q->nb[3],
-                K->nb[1], K->nb[2], K->nb[3],
-                KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
-        );
-    CUDA_CHECK(cudaGetLastError());
-
-    if (parallel_blocks == 1) {
-        return;
-    }
-
-    const dim3 block_dim_combine(D, 1, 1);
-    const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
-    const int  shmem_combine = 0;
-
-    flash_attn_combine_results<D, parallel_blocks>
-        <<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
-        (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
-    CUDA_CHECK(cudaGetLastError());
-}
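
For a concrete feel for the slope parameters computed in the launcher above: with max_bias = 8.0f and n_head = 32 heads, n_head_log2 = 32, m0 = 2^(-8/32) ≈ 0.841 and m1 = 2^(-8/64) ≈ 0.917, so head 0 gets slope 0.841^1 ≈ 0.841 while head 31 gets 0.841^32 = 2^-8 ≈ 0.0039 (numbers chosen for illustration only).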
 void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * Q = dst->src[0];
-    const ggml_tensor * K = dst->src[1];
-    const ggml_tensor * V = dst->src[2];
-    const ggml_tensor * mask = dst->src[3];
-    ggml_tensor * KQV = dst;
-
-    GGML_ASSERT(Q->ne[0] == 64 || Q->ne[0] == 128 && "FlashAttention without tensor cores only supports head sizes 64 and 128.");

     if (Q->ne[1] == 1) {
         constexpr int cols_per_block  = 1;
         constexpr int parallel_blocks = 4;
-        switch (Q->ne[0]) {
-            case 64:
-                launch_fattn_vec_f32< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            case 128:
-                launch_fattn_vec_f32<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            default:
-                GGML_ASSERT(false);
-                break;
-        }
+        launch_fattn_vec_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
         return;
     }

     if (Q->ne[1] == 2) {
         constexpr int cols_per_block  = 2;
         constexpr int parallel_blocks = 4;
-        switch (Q->ne[0]) {
-            case 64:
-                launch_fattn_vec_f32< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            case 128:
-                launch_fattn_vec_f32<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            default:
-                GGML_ASSERT(false);
-                break;
-        }
+        launch_fattn_vec_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
         return;
     }

     if (Q->ne[1] <= 4) {
         constexpr int cols_per_block  = 4;
         constexpr int parallel_blocks = 4;
-        switch (Q->ne[0]) {
-            case 64:
-                launch_fattn_vec_f32< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            case 128:
-                launch_fattn_vec_f32<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            default:
-                GGML_ASSERT(false);
-                break;
-        }
+        launch_fattn_vec_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
         return;
     }

     if (Q->ne[1] <= 8) {
         constexpr int cols_per_block  = 8;
         constexpr int parallel_blocks = 4;
-        switch (Q->ne[0]) {
-            case 64:
-                launch_fattn_vec_f32< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            case 128:
-                launch_fattn_vec_f32<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            default:
-                GGML_ASSERT(false);
-                break;
-        }
+        launch_fattn_vec_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
         return;
     }

     constexpr int cols_per_block  = 8;
     constexpr int parallel_blocks = 1;
-    switch (Q->ne[0]) {
-        case 64:
-            launch_fattn_vec_f32< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-            break;
-        case 128:
-            launch_fattn_vec_f32<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-            break;
-        default:
-            GGML_ASSERT(false);
-            break;
-    }
+    launch_fattn_vec_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
 }


@@ -1,5 +1,7 @@
 #include "common.cuh"
 #include "fattn-common.cuh"
+#include "fattn-tile-f16.cuh"
+#include "fattn-tile-f32.cuh"
 #include "fattn-vec-f16.cuh"
 #include "fattn-vec-f32.cuh"
 #include "fattn.cuh"

@@ -83,19 +85,9 @@ static __global__ void flash_attn_ext_f16(
     const int stride_Q  = nb01 / sizeof(float);
     const int stride_KV = nb11 / sizeof(half);

-    half  slopeh = __float2half(1.0f);
-    half2 slope2 = make_half2(1.0f, 1.0f);
-
-    // ALiBi
-    if (max_bias > 0.0f) {
-        const int h = blockIdx.y;
-
-        const float base = h < n_head_log2 ? m0 : m1;
-        const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
-
-        slopeh = __float2half(powf(base, exph));
-        slope2 = make_half2(slopeh, slopeh);
-    }
+    const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
+    const half  slopeh = __float2half(slopef);
+    const half2 slope2 = make_half2(slopef, slopef);

     frag_b Q_b[D/16][ncols/frag_n];

@@ -437,117 +429,64 @@ static_assert(get_VKQ_stride( 80, 1, 16) == 16, "Test failed.");
 static_assert(get_VKQ_stride( 80, 2, 16) == 16, "Test failed.");
 static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed.");

-template <int D, int cols_per_block, int nwarps, int parallel_blocks, typename KQ_acc_t> void launch_fattn_f16_impl(
-        const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
-        ggml_cuda_pool & pool, cudaStream_t main_stream
-) {
-    ggml_cuda_pool_alloc<float>  dst_tmp(pool);
-    ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
-
-    if (parallel_blocks > 1) {
-        dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
-        dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
-    }
-
-    constexpr int frag_m = (cols_per_block) == 8 && (D) % 32 == 0 ? 32 : 16;
-    const dim3 block_dim(WARP_SIZE, nwarps, 1);
-    const dim3 blocks_num(parallel_blocks*(Q->ne[1] + cols_per_block - 1) / cols_per_block, Q->ne[2], Q->ne[3]);
-    const int  shmem = 0;
-
-    float scale    = 1.0f;
-    float max_bias = 0.0f;
-
-    memcpy(&scale,    (float *) KQV->op_params + 0, sizeof(float));
-    memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
-
-    const uint32_t n_head      = Q->ne[2];
-    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
-
-    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
-    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
-
-    flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>
-        <<<blocks_num, block_dim, shmem, main_stream>>> (
-                (const char *) Q->data,
-                (const char *) K->data,
-                (const char *) V->data,
-                mask ? ((const char *) mask->data) : nullptr,
-                (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
-                scale, max_bias, m0, m1, n_head_log2,
-                Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
-                K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-                mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
-                Q->nb[1], Q->nb[2], Q->nb[3],
-                K->nb[1], K->nb[2], K->nb[3],
-                KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
-        );
-    CUDA_CHECK(cudaGetLastError());
-
-    if ((parallel_blocks) == 1) {
-        return;
-    }
-
-    const dim3 block_dim_combine(D, 1, 1);
-    const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
-    const int  shmem_combine = 0;
-
-    flash_attn_combine_results<D, parallel_blocks>
-        <<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
-        (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
-    CUDA_CHECK(cudaGetLastError());
-}
-
-template <int D, int cols_per_block, int nwarps, typename KQ_acc_t> void launch_fattn_f16(
-        const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
-        const int nsm, ggml_cuda_pool & pool, cudaStream_t main_stream
-) {
+template <int D, int cols_per_block, int nwarps, typename KQ_acc_t>
+void launch_fattn_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * Q = dst->src[0];
+
+    constexpr int frag_m = cols_per_block == 8 && D % 32 == 0 ? 32 : 16;
+
     const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3];
+    const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;

     if (4*blocks_num_pb1 < 2*nsm) {
-        launch_fattn_f16_impl<D, cols_per_block, nwarps, 4, KQ_acc_t>(Q, K, V, KQV, mask, pool, main_stream);
+        constexpr int parallel_blocks = 4;
+        fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
+        launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
         return;
     }
     if (2*blocks_num_pb1 < 2*nsm) {
-        launch_fattn_f16_impl<D, cols_per_block, nwarps, 2, KQ_acc_t>(Q, K, V, KQV, mask, pool, main_stream);
+        constexpr int parallel_blocks = 2;
+        fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
+        launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
        return;
     }
-    launch_fattn_f16_impl<D, cols_per_block, nwarps, 1, KQ_acc_t>(Q, K, V, KQV, mask, pool, main_stream);
+    constexpr int parallel_blocks = 1;
+    fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
+    launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
 }
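
The thresholds above pick the largest parallel_blocks that still leaves the GPU under-occupied. For illustration, assume a GPU with nsm = 108 and cols_per_block = 16 on a batch with Q->ne[1] = 256 and ne[2] = ne[3] = 1: blocks_num_pb1 = 16 and 4*16 = 64 < 2*108 = 216, so parallel_blocks = 4. With Q->ne[1] = 4096 instead, blocks_num_pb1 = 256 and both tests fail, so parallel_blocks = 1.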
 void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * KQV = dst;
     const ggml_tensor * Q   = dst->src[0];
-    const ggml_tensor * K   = dst->src[1];
-    const ggml_tensor * V   = dst->src[2];
-    const ggml_tensor * mask = dst->src[3];
-    ggml_tensor * KQV = dst;
-
-    GGML_ASSERT(Q->type == GGML_TYPE_F32);
-    GGML_ASSERT(K->type == GGML_TYPE_F16);
-    GGML_ASSERT(V->type == GGML_TYPE_F16);
-    GGML_ASSERT(KQV->type == GGML_TYPE_F32);
-
-    GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
-    GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
-                "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
-
-    GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");

     ggml_cuda_set_device(ctx.device);
     const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
-    const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;

     const int32_t precision = KQV->op_params[2];

-    if (!fast_fp16_available(cc)) {
-        ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
+    // On AMD the tile kernels perform poorly, use the vec kernel instead:
+    if (cc >= CC_OFFSET_AMD) {
+        if (precision == GGML_PREC_DEFAULT) {
+            ggml_cuda_flash_attn_ext_vec_f16_no_mma(ctx, dst);
+        } else {
+            ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
+        }
+        return;
+    }
+
+    if (!fast_fp16_available(cc)) {
+        if (Q->ne[1] <= 8) {
+            ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
+        } else {
+            ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
+        }
         return;
     }

     if (!fp16_mma_available(cc)) {
-        ggml_cuda_flash_attn_ext_vec_f16_no_mma(ctx, dst);
+        if (Q->ne[1] <= 8) {
+            ggml_cuda_flash_attn_ext_vec_f16_no_mma(ctx, dst);
+        } else {
+            ggml_cuda_flash_attn_ext_tile_f16(ctx, dst);
+        }
         return;
     }
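
Summarizing the fallback ladder above (annotation only, not part of the diff):

    // cc >= CC_OFFSET_AMD       -> vec f16 no-mma (default precision) or vec f32
    // !fast_fp16_available(cc)  -> Q->ne[1] <= 8 ? vec f32 : tile f32
    // !fp16_mma_available(cc)   -> Q->ne[1] <= 8 ? vec f16 no-mma : tile f16
    // otherwise                 -> tensor-core (WMMA) kernels below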
@@ -562,22 +501,22 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
         constexpr int nwarps = 4;
         switch (Q->ne[0]) {
             case 64:
-                launch_fattn_f16< 64, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                launch_fattn_f16< 64, cols_per_block, nwarps, float>(ctx, dst);
                 break;
             case 80:
-                launch_fattn_f16< 80, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                launch_fattn_f16< 80, cols_per_block, nwarps, float>(ctx, dst);
                 break;
             case 96:
-                launch_fattn_f16< 96, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                launch_fattn_f16< 96, cols_per_block, nwarps, float>(ctx, dst);
                 break;
             case 112:
-                launch_fattn_f16<112, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                launch_fattn_f16<112, cols_per_block, nwarps, float>(ctx, dst);
                 break;
             case 128:
-                launch_fattn_f16<128, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                launch_fattn_f16<128, cols_per_block, nwarps, float>(ctx, dst);
                 break;
             case 256:
-                launch_fattn_f16<256, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                launch_fattn_f16<256, cols_per_block, nwarps, float>(ctx, dst);
                 break;
             default:
                 GGML_ASSERT(false);

@@ -588,22 +527,22 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
         constexpr int nwarps = 4;
         switch (Q->ne[0]) {
             case 64:
-                launch_fattn_f16< 64, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                launch_fattn_f16< 64, cols_per_block, nwarps, float>(ctx, dst);
                 break;
             case 80:
-                launch_fattn_f16< 80, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                launch_fattn_f16< 80, cols_per_block, nwarps, float>(ctx, dst);
                 break;
             case 96:
-                launch_fattn_f16< 96, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                launch_fattn_f16< 96, cols_per_block, nwarps, float>(ctx, dst);
                 break;
             case 112:
-                launch_fattn_f16<112, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                launch_fattn_f16<112, cols_per_block, nwarps, float>(ctx, dst);
                 break;
             case 128:
-                launch_fattn_f16<128, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                launch_fattn_f16<128, cols_per_block, nwarps, float>(ctx, dst);
                 break;
             // case 256:
-            //     launch_fattn_f16<256, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+            //     launch_fattn_f16<256, cols_per_block, nwarps, float>(ctx, dst);
             //     break;
             default:
                 GGML_ASSERT(false);

@@ -623,16 +562,16 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
         constexpr int nwarps = 4;
         switch (Q->ne[0]) {
             case 64:
-                launch_fattn_f16< 64, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                launch_fattn_f16< 64, cols_per_block, nwarps, half>(ctx, dst);
                 break;
             case 96:
-                launch_fattn_f16< 96, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                launch_fattn_f16< 96, cols_per_block, nwarps, half>(ctx, dst);
                 break;
             case 128:
-                launch_fattn_f16<128, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                launch_fattn_f16<128, cols_per_block, nwarps, half>(ctx, dst);
                 break;
             case 256:
-                launch_fattn_f16<256, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                launch_fattn_f16<256, cols_per_block, nwarps, half>(ctx, dst);
                 break;
             default:
                 GGML_ASSERT(false);

@@ -646,22 +585,22 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
         constexpr int nwarps = 4;
         switch (Q->ne[0]) {
             case 64:
-                launch_fattn_f16< 64, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                launch_fattn_f16< 64, cols_per_block, nwarps, half>(ctx, dst);
                 break;
             case 80:
-                launch_fattn_f16< 80, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                launch_fattn_f16< 80, cols_per_block, nwarps, half>(ctx, dst);
                 break;
             case 96:
-                launch_fattn_f16< 96, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                launch_fattn_f16< 96, cols_per_block, nwarps, half>(ctx, dst);
                 break;
             case 112:
-                launch_fattn_f16<112, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                launch_fattn_f16<112, cols_per_block, nwarps, half>(ctx, dst);
                 break;
             case 128:
-                launch_fattn_f16<128, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                launch_fattn_f16<128, cols_per_block, nwarps, half>(ctx, dst);
                 break;
             case 256:
-                launch_fattn_f16<256, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                launch_fattn_f16<256, cols_per_block, nwarps, half>(ctx, dst);
                 break;
             default:
                 GGML_ASSERT(false);

@@ -674,22 +613,22 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
         constexpr int nwarps = 4;
         switch (Q->ne[0]) {
             case 64:
-                launch_fattn_f16< 64, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                launch_fattn_f16< 64, cols_per_block, nwarps, half>(ctx, dst);
                 break;
             case 80:
-                launch_fattn_f16< 80, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                launch_fattn_f16< 80, cols_per_block, nwarps, half>(ctx, dst);
                 break;
             case 96:
-                launch_fattn_f16< 96, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                launch_fattn_f16< 96, cols_per_block, nwarps, half>(ctx, dst);
                 break;
             case 112:
-                launch_fattn_f16<112, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                launch_fattn_f16<112, cols_per_block, nwarps, half>(ctx, dst);
                 break;
             case 128:
-                launch_fattn_f16<128, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                launch_fattn_f16<128, cols_per_block, nwarps, half>(ctx, dst);
                 break;
             case 256:
-                launch_fattn_f16<256, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                launch_fattn_f16<256, cols_per_block, nwarps, half>(ctx, dst);
                 break;
             default:
                 GGML_ASSERT(false);


@@ -1,3 +1,4 @@
+#include "common.cuh"
 #include "softmax.cuh"

 template <typename T>

@@ -23,17 +24,7 @@ static __global__ void soft_max_f32(const float * x, const T * mask, float * dst
     const int warp_id = threadIdx.x / WARP_SIZE;
     const int lane_id = threadIdx.x % WARP_SIZE;

-    float slope = 1.0f;
-
-    // ALiBi
-    if (max_bias > 0.0f) {
-        const int h = rowx/nrows_y; // head index
-
-        const float base = h < n_head_log2 ? m0 : m1;
-        const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
-
-        slope = powf(base, exph);
-    }
+    const float slope = get_alibi_slope(max_bias, rowx/nrows_y, n_head_log2, m0, m1);

     extern __shared__ float data_soft_max_f32[];
     float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication


@@ -14,6 +14,12 @@
 #include <stdlib.h> // for qsort
 #include <stdio.h>  // for GGML_ASSERT

+#define GROUP_MAX_EPS 1e-15f
+#define GROUP_MAX_EPS_IQ3_XXS 1e-8f
+#define GROUP_MAX_EPS_IQ2_S 1e-8f
+#define GROUP_MAX_EPS_IQ1_M 1e-7f
+#define GROUP_MAX_EPS_IQ1_S 1e-12f
+
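
These thresholds replace exact zero tests on group maxima: a tiny but nonzero maximum slips past `if (!max)` and then produces an enormous inverse scale. The guarded pattern, in isolation (illustrative sketch assembled from the hunks below, not a verbatim function):

    float amax = 0.0f;
    for (int i = 0; i < n; ++i) {
        amax = MAX(amax, fabsf(x[i]));
    }
    if (amax < GROUP_MAX_EPS) { // near-zero group: quantize to all zeros instead of dividing by ~0
        memset(L, 0, n);
        return 0.f;
    }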
 #if defined(_MSC_VER)
 // disable "possible loss of data" to avoid warnings for hundreds of casts
 // we should just be careful :)

@@ -1110,7 +1116,7 @@ static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t *
         float ax = fabsf(x[i]);
         if (ax > amax) { amax = ax; max = x[i]; }
     }
-    if (amax < 1e-30f) { // all zero
+    if (amax < GROUP_MAX_EPS) { // all zero
         for (int i = 0; i < n; ++i) {
             L[i] = 0;
         }

@@ -1178,7 +1184,7 @@ static float make_q3_quants(int n, int nmax, const float * restrict x, int8_t *
         float ax = fabsf(x[i]);
         if (ax > amax) { amax = ax; max = x[i]; }
     }
-    if (!amax) { // all zero
+    if (amax < GROUP_MAX_EPS) { // all zero
         for (int i = 0; i < n; ++i) { L[i] = 0; }
         return 0.f;
     }

@@ -2654,7 +2660,7 @@ void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict
         }
-        if (!max_abs_scale) {
+        if (max_abs_scale < GROUP_MAX_EPS) {
             memset(&y[i], 0, sizeof(block_q6_K));
             y[i].d = GGML_FP32_TO_FP16(0.f);
             x += QK_K;

@@ -2806,7 +2812,7 @@ static void quantize_row_q6_K_impl(const float * restrict x, block_q6_K * restri
         }
-        if (!max_abs_scale) {
+        if (max_abs_scale < GROUP_MAX_EPS) {
             memset(&y[i], 0, sizeof(block_q6_K));
             y[i].d = GGML_FP32_TO_FP16(0.f);
             x += QK_K;

@@ -12600,7 +12606,7 @@ static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict
         }
         float max = xval[0];
         for (int i = 1; i < 32; ++i) max = MAX(max, xval[i]);
-        if (!max) {
+        if (max < GROUP_MAX_EPS) {
             scales[ib] = 0;
             memset(L, 0, 32);
             continue;

@@ -12776,7 +12782,7 @@ static void quantize_row_iq2_xs_impl(const float * restrict x, void * restrict v
         }
         float max = xval[0];
         for (int i = 1; i < 16; ++i) max = MAX(max, xval[i]);
-        if (!max) {
+        if (max < GROUP_MAX_EPS) {
             scales[ib] = 0;
             memset(L, 0, 16);
             continue;

@@ -13217,7 +13223,7 @@ static void quantize_row_iq3_xxs_impl(int grid_size, const float * restrict x, v
         }
         float max = xval[0];
         for (int i = 1; i < 32; ++i) max = MAX(max, xval[i]);
-        if (!max) {
+        if (max < GROUP_MAX_EPS_IQ3_XXS) {
             scales[ib] = 0;
             memset(L, 0, 32);
             continue;

@@ -13757,7 +13763,7 @@ static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy
         for (int i = 0; i < block_size; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
         float max = fabsf(xb[0]);
         for (int i = 1; i < block_size; ++i) max = MAX(max, fabsf(xb[i]));
-        if (!max) {
+        if (max < GROUP_MAX_EPS_IQ1_S) {
             scales[ib] = 0;
             memset(L, 1, block_size);
             continue;

@@ -13945,7 +13951,7 @@ static void quantize_row_iq1_m_impl(const float * restrict x, void * restrict vy
         }
         float max = fabsf(xb[0]);
         for (int i = 1; i < block_size; ++i) max = MAX(max, fabsf(xb[i]));
-        if (!max) {
+        if (max < GROUP_MAX_EPS_IQ1_M) {
             scales[ib] = 0;
             memset(L, 1, block_size);
             continue;

@@ -14209,7 +14215,7 @@ static void quantize_row_iq4_nl_impl(const int super_block_size, const int block
                 amax = ax; max = xb[j];
             }
         }
-        if (!amax) {
+        if (amax < GROUP_MAX_EPS) {
             scales[ib] = 0;
             continue;
         }

@@ -14430,7 +14436,7 @@ static void quantize_row_iq2_s_impl(const float * restrict x, void * restrict vy
         }
         float max = xval[0];
         for (int i = 1; i < 16; ++i) max = MAX(max, xval[i]);
-        if (!max) {
+        if (max < GROUP_MAX_EPS_IQ2_S) {
             scales[ib] = 0;
             continue;
         }


@@ -134,7 +134,13 @@ static bool set_no_delay(sockfd_t sockfd) {
     int flag = 1;
     // set TCP_NODELAY to disable Nagle's algorithm
     int ret = setsockopt(sockfd, IPPROTO_TCP, TCP_NODELAY, (char *)&flag, sizeof(int));
-    return ret >= 0;
+    return ret == 0;
+}
+
+static bool set_reuse_addr(sockfd_t sockfd) {
+    int flag = 1;
+    int ret = setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, (char *)&flag, sizeof(int));
+    return ret == 0;
 }
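
SO_REUSEADDR lets a restarted server re-bind its port while the previous socket still lingers in TIME_WAIT, avoiding spurious bind failures. The option must be set before bind(); a sketch of the intended call order (assumed, mirroring create_server_socket below, not a verbatim excerpt):

    sockfd_t sockfd = socket(AF_INET, SOCK_STREAM, 0);
    if (!set_reuse_addr(sockfd)) {   // must run before bind()
        fprintf(stderr, "Failed to set SO_REUSEADDR\n");
        return nullptr;
    }
    // bind(), listen(), accept() follow as usual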
 static std::shared_ptr<socket_t> socket_connect(const char * host, int port) {

@@ -181,7 +187,10 @@ static std::shared_ptr<socket_t> create_server_socket(const char * host, int por
     if (sock == nullptr) {
         return nullptr;
     }
+    if (!set_reuse_addr(sockfd)) {
+        fprintf(stderr, "Failed to set SO_REUSEADDR\n");
+        return nullptr;
+    }
     struct sockaddr_in serv_addr;
     serv_addr.sin_family = AF_INET;
     serv_addr.sin_addr.s_addr = inet_addr(host);

File diff suppressed because it is too large


@@ -294,7 +294,6 @@ struct vk_op_rope_neox_push_constants {
 struct vk_op_soft_max_push_constants {
     uint32_t KX;
     uint32_t KY;
-    uint32_t KZ;
     float scale;
     float max_bias;
     float m0;

@@ -304,7 +303,8 @@ struct vk_op_soft_max_push_constants {
 struct vk_op_argsort_push_constants {
     uint32_t ncols;
-    bool ascending;
+    uint32_t ncols_pad;
+    int32_t order;
 };

 // Allow pre-recording command buffers

@@ -1501,8 +1501,8 @@ static void ggml_vk_load_shaders(ggml_backend_vk_context * ctx) {
     ggml_vk_create_pipeline(ctx, ctx->device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {512, 1, 1}, {}, 1);

-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, {}, 1);

     ggml_vk_create_pipeline(ctx, ctx->device->pipeline_rope_f32, "rope_f32", rope_f32_len, rope_f32_data, "main", 3, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
     ggml_vk_create_pipeline(ctx, ctx->device->pipeline_rope_f16, "rope_f16", rope_f16_len, rope_f16_data, "main", 3, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);

@@ -3752,7 +3752,7 @@ static void ggml_vk_op_repeat(ggml_backend_vk_context * ctx, vk_context * subctx
 }

-static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op) {
+static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, ggml_op op) {
     switch (op) {
     case GGML_OP_ADD:
         if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {

@@ -3834,7 +3834,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
         if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
             return ctx->device->pipeline_soft_max_f32;
         }
-        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16 && src2->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
+        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
            return ctx->device->pipeline_soft_max_f32_f16;
         }
         return nullptr;

@@ -3900,15 +3900,12 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
 }

 template<typename PC>
-static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op, const PC&& pc) {
+static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, ggml_op op, const PC&& pc) {
 #ifdef GGML_VULKAN_DEBUG
     std::cerr << "ggml_vk_op_f32((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", backend=" << src0->backend << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
     if (src1 != nullptr) {
         std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", backend=" << src1->backend << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
     }
-    if (src2 != nullptr) {
-        std::cerr << "), (" << src2 << ", name=" << src2->name << ", type=" << src2->type << ", backend=" << src2->backend << ", ne0=" << src2->ne[0] << ", ne1=" << src2->ne[1] << ", ne2=" << src2->ne[2] << ", ne3=" << src2->ne[3] << ", nb0=" << src2->nb[0] << ", nb1=" << src2->nb[1] << ", nb2=" << src2->nb[2] << ", nb3=" << src2->nb[3];
-    }
     std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", backend=" << dst->backend << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3] << "), " << ggml_op_name(op) << ")" << std::endl;
 #endif
     GGML_ASSERT(op == GGML_OP_GET_ROWS || (!ggml_is_quantized(src0->type) && (src1 == nullptr || !ggml_is_quantized(src1->type)))); // NOLINT

@@ -3929,10 +3926,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c
     const uint64_t nb2 = dst->nb[2];
     const uint64_t nb3 = dst->nb[3];

-    const bool use_src2 = src2 != nullptr;
-    const uint64_t ne2 = use_src2 ? src2->ne[0] * src2->ne[1] : 0;
-
-    vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, dst, op);
+    vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, dst, op);
     ggml_vk_func_t op_func;

     if (pipeline == nullptr) {

@@ -3955,18 +3949,15 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c
     ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) dst->extra;
     ggml_tensor_extra_gpu * extra_src0 = (ggml_tensor_extra_gpu *) src0->extra;
     ggml_tensor_extra_gpu * extra_src1 = use_src1 ? (ggml_tensor_extra_gpu *) src1->extra : nullptr;
-    ggml_tensor_extra_gpu * extra_src2 = use_src2 ? (ggml_tensor_extra_gpu *) src2->extra : nullptr;

     vk_buffer d_X = nullptr;
     size_t x_buf_offset = 0;
     vk_buffer d_Y = nullptr;
     size_t y_buf_offset = 0;
     vk_buffer d_Z = nullptr;
-    size_t z_buf_offset = 0;

     bool src0_uma = false;
     bool src1_uma = false;
-    bool src2_uma = false;

     if (ctx->device->uma) {
         ggml_vk_host_get(ctx, src0->data, d_X, x_buf_offset);

@@ -3975,15 +3966,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c
             ggml_vk_host_get(ctx, src1->data, d_Y, y_buf_offset);
             src1_uma = d_Y != nullptr;
         }
-        if (use_src2) {
-            ggml_vk_host_get(ctx, src1->data, d_Z, z_buf_offset);
-            src2_uma = d_Z != nullptr;
-        }
     }

     uint64_t x_sz = ggml_vk_align_size(ggml_type_size(src0->type)/ggml_blck_size(src0->type) * ne0, ctx->device->properties.limits.minStorageBufferOffsetAlignment);
     uint64_t y_sz = use_src1 ? ggml_vk_align_size(ggml_type_size(src1->type) * ne1, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : 0;
-    uint64_t z_sz = use_src2 ? ggml_vk_align_size(ggml_type_size(src2->type) * ne2, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : 0;
     uint64_t d_sz = ggml_type_size(dst->type) * ne0;

     vk_buffer d_D = extra->buffer_gpu.lock();

@@ -4007,12 +3993,6 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c
         GGML_ASSERT(d_Y != nullptr);
     }

-    if (use_src2 && !src2_uma) {
-        d_Z = extra_src2->buffer_gpu.lock();
-        z_buf_offset = extra_src2->offset;
-        GGML_ASSERT(d_Z != nullptr);
-    }
-
     if (op_supports_incontiguous) {
         x_sz = ggml_nbytes(src0);
         y_sz = use_src1 ? ggml_nbytes(src1) : 0;

@@ -4048,6 +4028,9 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c
     case GGML_OP_GET_ROWS:
         elements = { (uint32_t)ne00, (uint32_t)ne10, (uint32_t)(ne11 * ne12) };
         break;
+    case GGML_OP_ARGSORT:
+        elements = { (uint32_t)ne00, (uint32_t)ggml_nrows(src0), 1 };
+        break;
     default:
         elements = { (uint32_t)ggml_nelements(src0), 1, 1 };
         break;

@@ -4066,7 +4049,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c
     }

     if (op == GGML_OP_SOFT_MAX) {
-        // Empty src1 and src2 are possible on soft_max, but the shader needs buffers
+        // Empty src1 is possible on soft_max, but the shader needs a buffer
         vk_subbuffer subbuf_y;
         if (use_src1) {
             subbuf_y = { d_Y, y_buf_offset, y_sz };

@@ -4074,15 +4057,8 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c
             subbuf_y = { d_X, 0, d_X->size };
         }

-        vk_subbuffer subbuf_z;
-        if (use_src2) {
-            subbuf_z = { d_Z, z_buf_offset, z_sz };
-        } else {
-            subbuf_z = { d_X, 0, d_X->size };
-        }
-
         ggml_vk_sync_buffers(subctx);
-        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { { d_X, x_buf_offset, x_sz }, subbuf_y, subbuf_z, { d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { { d_X, x_buf_offset, x_sz }, subbuf_y, { d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
     } else if (use_src1) {
         ggml_vk_sync_buffers(subctx);
         ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { { d_X, x_buf_offset, x_sz }, { d_Y, y_buf_offset, y_sz }, { d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);

@@ -4099,13 +4075,13 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c
         }
     } else {
         GGML_ASSERT(op != GGML_OP_SOFT_MAX);
+        GGML_ASSERT(op != GGML_OP_ARGSORT);

         ggml_pipeline_allocate_descriptor_sets(ctx, pipeline, ne02 * ne03);

         switch (dst->op) {
         case GGML_OP_NORM:
         case GGML_OP_RMS_NORM:
-        case GGML_OP_SOFT_MAX:
             elements = { (uint32_t)ne01, 1, 1 };
             break;
         case GGML_OP_DIAG_MASK_INF:

@@ -4145,7 +4121,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c
 }

 static void ggml_vk_repeat(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_REPEAT, { (uint32_t)ggml_nelements(src0), (uint32_t)ggml_nelements(src1), 0.0f, 0.0f });
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, dst, GGML_OP_REPEAT, { (uint32_t)ggml_nelements(src0), (uint32_t)ggml_nelements(src1), 0.0f, 0.0f });
 }

 static void ggml_vk_get_rows(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {

@@ -4153,7 +4129,7 @@ static void ggml_vk_get_rows(ggml_backend_vk_context * ctx, vk_context * subctx,
     const uint32_t src1_type_size = ggml_type_size(src1->type);
     const uint32_t dst_type_size = ggml_type_size(dst->type);

-    ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_GET_ROWS, {
+    ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, dst, GGML_OP_GET_ROWS, {
         (uint32_t)ggml_nelements(src0),
         (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
         (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,

@@ -4168,7 +4144,7 @@ static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context * subctx, cons
     const uint32_t src1_type_size = ggml_type_size(src1->type);
     const uint32_t dst_type_size = ggml_type_size(dst->type);

-    ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_ADD, {
+    ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, dst, GGML_OP_ADD, {
         (uint32_t)ggml_nelements(src0),
         (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
         (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,

@@ -4183,7 +4159,7 @@ static void ggml_vk_mul(ggml_backend_vk_context * ctx, vk_context * subctx, cons
     const uint32_t src1_type_size = ggml_type_size(src1->type);
     const uint32_t dst_type_size = ggml_type_size(dst->type);

-    ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_MUL, {
+    ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, dst, GGML_OP_MUL, {
         (uint32_t)ggml_nelements(src0),
         (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
         (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,

@@ -4198,7 +4174,7 @@ static void ggml_vk_scale(ggml_backend_vk_context * ctx, vk_context * subctx, co
     const uint32_t src0_type_size = ggml_type_size(src0->type);
     const uint32_t dst_type_size = ggml_type_size(dst->type);

-    ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SCALE, {
+    ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, dst, GGML_OP_SCALE, {
         (uint32_t)ggml_nelements(src0),
         (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
         (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,

@@ -4211,7 +4187,7 @@ static void ggml_vk_sqr(ggml_backend_vk_context * ctx, vk_context * subctx, cons
     const uint32_t src0_type_size = ggml_type_size(src0->type);
     const uint32_t dst_type_size = ggml_type_size(dst->type);

-    ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SQR, {
+    ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, dst, GGML_OP_SQR, {
         (uint32_t)ggml_nelements(src0),
         (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
         (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,

@@ -4225,7 +4201,7 @@ static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context * subctx, co
     const uint32_t src0_type_size = ggml_type_size(src0->type);
     const uint32_t dst_type_size = ggml_type_size(dst->type);

-    ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CLAMP, {
+    ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, dst, GGML_OP_CLAMP, {
         (uint32_t)ggml_nelements(src0),
         (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
         (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,

@@ -4240,7 +4216,7 @@ static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context * subctx, cons
     const uint32_t dst_type_size = ggml_type_size(dst->type);
     const uint32_t d_offset = (extra->offset % ctx->device->properties.limits.minStorageBufferOffsetAlignment) / dst_type_size;

-    ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CPY, {
+    ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, dst, GGML_OP_CPY, {
         (uint32_t)ggml_nelements(src0),
         (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
         (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,

@@ -4252,24 +4228,24 @@ static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context * subctx, cons

 static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
     float * op_params = (float *)dst->op_params;
-    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f });
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, dst, GGML_OP_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f });
 }

 static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
     float * op_params = (float *)dst->op_params;
-    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f });
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, dst, GGML_OP_RMS_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f });
 }

 static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
-    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f });
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f });
 }

 static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
     int32_t * op_params = (int32_t *)dst->op_params;
-    ggml_vk_op_f32<vk_op_diag_mask_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_DIAG_MASK_INF, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0] });
+    ggml_vk_op_f32<vk_op_diag_mask_push_constants>(ctx, subctx, src0, nullptr, dst, GGML_OP_DIAG_MASK_INF, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0] });
 }

-static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
+static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     float * op_params = (float *)dst->op_params;

     float scale = op_params[0];

@@ -4285,13 +4261,9 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context * subctx,
     const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

-#pragma message("TODO: src2 is no longer used in soft_max - should be removed and ALiBi calculation should be updated")
-#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/7192")
-
-    ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_SOFT_MAX, {
+    ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, dst, GGML_OP_SOFT_MAX, {
         ncols,
         src1 != nullptr ? nrows_y : (uint32_t)0,
-        src2 != nullptr ? (uint32_t)1 : (uint32_t)0,
         scale, max_bias,
         m0, m1,
         n_head_log2,

@@ -4321,15 +4293,39 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context * subctx, con
     if (is_neox) {
         const float theta_scale = powf(freq_base, -2.0f/n_dims);
const float inv_ndims = -1.0f / n_dims; const float inv_ndims = -1.0f / n_dims;
ggml_vk_op_f32<vk_op_rope_neox_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_ROPE, { (uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1], freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1], 0.0f, 0.0f}, theta_scale, inv_ndims }); ggml_vk_op_f32<vk_op_rope_neox_push_constants>(ctx, subctx, src0, src1, dst, GGML_OP_ROPE, {
(uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1],
freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1], 0.0f, 0.0f}, theta_scale, inv_ndims
});
} else { } else {
ggml_vk_op_f32<vk_op_rope_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_ROPE, { (uint32_t)src0->ne[0], freq_scale, (uint32_t)src0->ne[1], freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1], 0.0f, 0.0f} }); ggml_vk_op_f32<vk_op_rope_push_constants>(ctx, subctx, src0, src1, dst, GGML_OP_ROPE, {
(uint32_t)src0->ne[0], freq_scale, (uint32_t)src0->ne[1],
freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1], 0.0f, 0.0f}
});
} }
} }
static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) { static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
int32_t * op_params = (int32_t *)dst->op_params; int32_t * op_params = (int32_t *)dst->op_params;
ggml_vk_op_f32<vk_op_argsort_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ARGSORT, { (uint32_t)src0->ne[0], ((ggml_sort_order) op_params[0]) == GGML_SORT_ORDER_ASC });
uint32_t ncols = src0->ne[0];
uint32_t ncols_pad = 1;
while (ncols_pad < ncols) {
ncols_pad *= 2;
}
GGML_ASSERT(ncols_pad <= 1024);
std::cerr << "ncols=" << ncols << " ncols_pad=" << ncols_pad << " ascending=" << op_params[0] << std::endl;
std::cerr << ((ggml_sort_order) op_params[0]) << " " << GGML_SORT_ORDER_ASC << std::endl;
ggml_vk_op_f32<vk_op_argsort_push_constants>(ctx, subctx, src0, nullptr, dst, GGML_OP_ARGSORT, {
ncols,
ncols_pad,
op_params[0],
});
} }
#ifdef GGML_VULKAN_RUN_TESTS #ifdef GGML_VULKAN_RUN_TESTS
@ -5432,7 +5428,6 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
const ggml_tensor * src0 = node->src[0]; const ggml_tensor * src0 = node->src[0];
const ggml_tensor * src1 = node->src[1]; const ggml_tensor * src1 = node->src[1];
const ggml_tensor * src2 = node->src[2];
ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) node->extra; ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) node->extra;
@ -5547,7 +5542,7 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
break; break;
case GGML_OP_SOFT_MAX: case GGML_OP_SOFT_MAX:
ggml_vk_soft_max(ctx, ctx->compute_ctx, src0, src1, src2, node); ggml_vk_soft_max(ctx, ctx->compute_ctx, src0, src1, node);
break; break;
case GGML_OP_ROPE: case GGML_OP_ROPE:
@ -6548,7 +6543,7 @@ static void ggml_vk_print_graph_origin(const ggml_tensor * tensor, std::vector<c
} }
static void ggml_vk_print_tensor_area(const ggml_tensor * tensor, const void * data, int i0, int i1, int i2, int i3) { static void ggml_vk_print_tensor_area(const ggml_tensor * tensor, const void * data, int i0, int i1, int i2, int i3) {
if (tensor->type != GGML_TYPE_F32 && tensor->type != GGML_TYPE_F16) { if (tensor->type != GGML_TYPE_F32 && tensor->type != GGML_TYPE_F16 && tensor->type != GGML_TYPE_I32) {
return; return;
} }
i0 = std::max(i0, 5); i0 = std::max(i0, 5);
@ -6569,6 +6564,8 @@ static void ggml_vk_print_tensor_area(const ggml_tensor * tensor, const void * d
val = *(const float *) ((const char *) data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0]); val = *(const float *) ((const char *) data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0]);
} else if (tensor->type == GGML_TYPE_F16) { } else if (tensor->type == GGML_TYPE_F16) {
val = ggml_fp16_to_fp32(*(const ggml_fp16_t *) ((const char *) data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0])); val = ggml_fp16_to_fp32(*(const ggml_fp16_t *) ((const char *) data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0]));
} else if (tensor->type == GGML_TYPE_I32) {
val = *(const int32_t *) ((const char *) data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0]);
} else { } else {
GGML_ASSERT(false); GGML_ASSERT(false);
} }
@ -6671,7 +6668,6 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_
ggml_tensor * src0 = tensor->src[0]; ggml_tensor * src0 = tensor->src[0];
ggml_tensor * src1 = tensor->src[1]; ggml_tensor * src1 = tensor->src[1];
ggml_tensor * src2 = tensor->src[2];
struct ggml_init_params iparams = { struct ggml_init_params iparams = {
/*.mem_size =*/ 1024*1024*1024, /*.mem_size =*/ 1024*1024*1024,
@ -6798,66 +6794,6 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_
ggml_vk_check_tensor(std::string(ggml_op_name(tensor->op)) + "->src1", src1_clone); ggml_vk_check_tensor(std::string(ggml_op_name(tensor->op)) + "->src1", src1_clone);
} }
if (src2 != nullptr) {
src2_clone = ggml_dup_tensor(ggml_ctx, src2);
src2_size = ggml_nbytes(src2);
src2_buffer = malloc(src2_size);
src2_clone->data = src2_buffer;
if (src2->backend == GGML_BACKEND_TYPE_CPU) {
memcpy(src2_clone->data, src2->data, src2_size);
memcpy(src2_clone->nb, src2->nb, sizeof(size_t) * GGML_MAX_DIMS);
} else if (src2->backend == GGML_BACKEND_TYPE_GPU) {
ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) src2->extra;
vk_buffer buf = extra->buffer_gpu.lock();
uint64_t offset = extra->offset;
if (!ggml_is_contiguous(src2) && ggml_vk_dim01_contiguous(src2)) {
for (int i3 = 0; i3 < src2->ne[3]; i3++) {
for (int i2 = 0; i2 < src2->ne[2]; i2++) {
const int idx = i3*src2->ne[2] + i2;
ggml_vk_buffer_read(ctx, buf, offset + idx * src2->nb[2], ((char *)src2_clone->data + idx * src2_clone->nb[2]), src2->ne[1] * src2->nb[1]);
}
}
src2_clone->nb[0] = src2->nb[0];
src2_clone->nb[1] = src2->nb[1];
for (int i = 2; i < GGML_MAX_DIMS; i++) {
src2_clone->nb[i] = src2_clone->nb[i - 1]*src2_clone->ne[i - 1];
}
} else {
if (offset + src2_size >= buf->size) {
src2_size = buf->size - offset;
}
ggml_vk_buffer_read(ctx, buf, offset, src2_clone->data, src2_size);
memcpy(src2_clone->nb, src2->nb, sizeof(size_t) * GGML_MAX_DIMS);
}
} else {
GGML_ASSERT(false);
}
if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
ggml_vk_print_tensor(ctx, src2, "src2");
std::cerr << "TENSOR CHECK: " << ggml_op_name(src2_clone->op) << " (check " << check_counter << ")" << std::endl;
std::cerr << "src2_clone=" << tensor << " src2_clone->backend: " << src2_clone->backend << " src2_clone->type: " << ggml_type_name(src2_clone->type) << " ne0=" << src2_clone->ne[0] << " nb0=" << src2_clone->nb[0] << " ne1=" << src2_clone->ne[1] << " nb1=" << src2_clone->nb[1] << " ne2=" << src2_clone->ne[2] << " nb2=" << src2_clone->nb[2] << " ne3=" << src2_clone->ne[3] << " nb3=" << src2_clone->nb[3] << std::endl;
if (src2->src[0] != nullptr) {
std::cerr << "src2->src[0]=" << src2->src[0] << " op=" << ggml_op_name(src2->src[0]->op) << " type=" << ggml_type_name(src2->src[0]->type) << " backend=" << src2->src[0]->backend << " ne0=" << src2->src[0]->ne[0] << " nb0=" << src2->src[0]->nb[0] << " ne1=" << src2->src[0]->ne[1] << " nb1=" << src2->src[0]->nb[1] << " ne2=" << src2->src[0]->ne[2] << " nb2=" << src2->src[0]->nb[2] << " ne3=" << src2->src[0]->ne[3] << " nb3=" << src2->src[0]->nb[3] << std::endl;
}
if (src2->src[1] != nullptr) {
std::cerr << "src2->src[1]=" << src2->src[1] << " op=" << ggml_op_name(src2->src[1]->op) << " type=" << ggml_type_name(src2->src[1]->type) << " backend=" << src2->src[1]->backend << " ne0=" << src2->src[1]->ne[0] << " nb0=" << src2->src[1]->nb[0] << " ne1=" << src2->src[1]->ne[1] << " nb1=" << src2->src[1]->nb[1] << " ne2=" << src2->src[1]->ne[2] << " nb2=" << src2->src[1]->nb[2] << " ne3=" << src2->src[1]->ne[3] << " nb3=" << src2->src[1]->nb[3] << std::endl;
}
std::cerr << std::endl << "Result:" << std::endl;
ggml_vk_print_tensor_area(src2_clone, src2_clone->data, 5, 5, 0, 0);
std::cerr << std::endl;
std::cerr << std::endl << "Result:" << std::endl;
ggml_vk_print_tensor_area(src2_clone, src2_clone->data, 5, 5, 1, 0);
std::cerr << std::endl;
std::vector<const ggml_tensor *> done;
ggml_vk_print_graph_origin(src2_clone, done);
}
ggml_vk_check_tensor(std::string(ggml_op_name(tensor->op)) + "->src2", src2_clone);
}
if (tensor->op == GGML_OP_MUL_MAT) { if (tensor->op == GGML_OP_MUL_MAT) {
tensor_clone = ggml_mul_mat(ggml_ctx, src0_clone, src1_clone); tensor_clone = ggml_mul_mat(ggml_ctx, src0_clone, src1_clone);
@ -6877,7 +6813,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_
tensor_clone = ggml_rms_norm(ggml_ctx, src0_clone, *(float *)tensor->op_params); tensor_clone = ggml_rms_norm(ggml_ctx, src0_clone, *(float *)tensor->op_params);
} else if (tensor->op == GGML_OP_SOFT_MAX) { } else if (tensor->op == GGML_OP_SOFT_MAX) {
if (src1 != nullptr) { if (src1 != nullptr) {
tensor_clone = ggml_soft_max_ext(ggml_ctx, src0_clone, src1_clone, src2_clone, ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]); tensor_clone = ggml_soft_max_ext(ggml_ctx, src0_clone, src1_clone, ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
} else { } else {
tensor_clone = ggml_soft_max(ggml_ctx, src0_clone); tensor_clone = ggml_soft_max(ggml_ctx, src0_clone);
} }
@ -6964,9 +6900,6 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_
if (src1 != nullptr) { if (src1 != nullptr) {
free(src1_buffer); free(src1_buffer);
} }
if (src2 != nullptr) {
free(src2_buffer);
}
ggml_free(ggml_ctx); ggml_free(ggml_ctx);
} }
@ -7026,8 +6959,11 @@ static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_compute_
} else if (tensor->type == GGML_TYPE_F16) { } else if (tensor->type == GGML_TYPE_F16) {
correct = ggml_fp16_to_fp32(*(ggml_fp16_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0])); correct = ggml_fp16_to_fp32(*(ggml_fp16_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]));
result = ggml_fp16_to_fp32(*(ggml_fp16_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0])); result = ggml_fp16_to_fp32(*(ggml_fp16_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]));
} else if (tensor->type == GGML_TYPE_I32) {
correct = *(int32_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]);
result = *(int32_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]);
} else { } else {
std::cerr << "comp_size=" << comp_size << " but required is " << (i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]) << std::endl; std::cerr << "Results check not implemented for type " << ggml_type_name(tensor->type) << std::endl;
} }
} else { } else {
std::cerr << "Missing debug code for type " << ggml_type_name(tensor->type) << std::endl; std::cerr << "Missing debug code for type " << ggml_type_name(tensor->type) << std::endl;

ggml.c

@@ -2076,7 +2076,7 @@ inline static float ggml_silu_f32(float x) {
     return x/(1.0f + expf(-x));
 }

-#if defined(__ARM_NEON)
+#if defined(__ARM_NEON) && defined(__aarch64__)

 // adapted from arm limited optimized routine
 // the maximum error is 1.45358 plus 0.5 ulps
@@ -2288,7 +2288,7 @@ static void ggml_vec_silu_f32(const int n, float * y, const float * x) {
     for (; i + 3 < n; i += 4) {
         _mm_storeu_ps(y + i, ggml_v_silu(_mm_loadu_ps(x + i)));
     }
-#elif defined(__ARM_NEON)
+#elif defined(__ARM_NEON) && defined(__aarch64__)
     for (; i + 3 < n; i += 4) {
         vst1q_f32(y + i, ggml_v_silu(vld1q_f32(x + i)));
     }
@@ -2335,7 +2335,7 @@ static ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x,
 #endif
         sum += (ggml_float)_mm_cvtss_f32(val);
     }
-#elif defined(__ARM_NEON)
+#elif defined(__ARM_NEON) && defined(__aarch64__)
     for (; i + 3 < n; i += 4) {
         float32x4_t val = ggml_v_expf(vsubq_f32(vld1q_f32(x + i),
                                                 vdupq_n_f32(max)));
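The NEON branches above are now additionally gated on __aarch64__, since the vectorized ggml_v_silu/ggml_v_expf helpers use instructions that 32-bit ARM lacks (for example full-width float division), so those builds take the scalar loop instead. A sketch of the scalar semantics the vector path must match (ours, for illustration only):

#include <cmath>

// silu(x) = x * sigmoid(x) = x / (1 + exp(-x));
// the NEON path computes the same expression four lanes at a time.
static void vec_silu_f32_scalar(int n, float * y, const float * x) {
    for (int i = 0; i < n; ++i) {
        y[i] = x[i] / (1.0f + std::exp(-x[i]));
    }
}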

ggml_vk_generate_shaders.py

@@ -2432,7 +2432,6 @@ layout (push_constant) uniform parameter
 {
     uint KX;
     uint KY;
-    uint KZ;
     float scale;
     float max_bias;
     float m0;
@@ -2449,8 +2448,7 @@ layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
 layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
 layout (binding = 1) readonly buffer Y {B_TYPE data_b[];};
-layout (binding = 2) readonly buffer Z {C_TYPE data_c[];};
-layout (binding = 3) buffer D {D_TYPE data_d[];};
+layout (binding = 2) buffer D {D_TYPE data_d[];};

 shared FLOAT_TYPE vals[BLOCK_SIZE];

@@ -2459,7 +2457,7 @@ void main() {
     const uint rowx = gl_WorkGroupID.x;
     const uint rowy = rowx % p.KY;

-    float slope = 0.0f;
+    float slope = 1.0f;

     // ALiBi
     if (p.max_bias > 0.0f) {
@@ -2472,12 +2470,19 @@ void main() {
     }

     // Find max
-    vals[tid] = uintBitsToFloat(0xFF800000);
-
-    [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
-        vals[tid] = max(vals[tid], FLOAT_TYPE(data_a[rowx * p.KX + col]) * p.scale + (p.KY > 0 ? FLOAT_TYPE(data_b[rowy * p.KX + col]) : FLOAT_TYPE(0.0f)) + (p.KZ > 0 ? slope * FLOAT_TYPE(data_c[col]) : 0.0f));
+    FLOAT_TYPE max_val = uintBitsToFloat(0xFF800000);
+
+    [[unroll]] for (uint col0 = 0; col0 < p.KX; col0 += BLOCK_SIZE) {
+        const uint col = col0 + tid;
+
+        if (col >= p.KX) {
+            break;
+        }
+
+        max_val = max(max_val, FLOAT_TYPE(data_a[rowx * p.KX + col]) * p.scale + (p.KY > 0 ? slope * FLOAT_TYPE(data_b[rowy * p.KX + col]) : FLOAT_TYPE(0.0f)));
     }
+    vals[tid] = max_val;

     barrier();
     [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
         if (tid < s) {
@@ -2486,15 +2491,21 @@ void main() {
         barrier();
     }

-    const FLOAT_TYPE max_val = vals[0];
+    max_val = vals[0];
     barrier();

     // Sum up values
     vals[tid] = FLOAT_TYPE(0.0f);

-    [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
+    [[unroll]] for (uint col0 = 0; col0 < p.KX; col0 += BLOCK_SIZE) {
+        const uint col = col0 + tid;
+
+        if (col >= p.KX) {
+            break;
+        }
+
         const uint i = rowx * p.KX + col;
-        const FLOAT_TYPE val = exp(FLOAT_TYPE(data_a[i]) * p.scale + (p.KY > 0 ? FLOAT_TYPE(data_b[rowy * p.KX + col]) : FLOAT_TYPE(0.0f)) - max_val);
+        const FLOAT_TYPE val = exp(FLOAT_TYPE(data_a[i]) * p.scale + (p.KY > 0 ? slope * FLOAT_TYPE(data_b[rowy * p.KX + col]) : FLOAT_TYPE(0.0f)) - max_val);
         vals[tid] += val;
         data_d[i] = D_TYPE(val);
     }
@@ -2509,7 +2520,13 @@ void main() {
     const D_TYPE divisor = D_TYPE(vals[0]);

-    [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
+    [[unroll]] for (uint col0 = 0; col0 < p.KX; col0 += BLOCK_SIZE) {
+        const uint col = col0 + tid;
+
+        if (col >= p.KX) {
+            break;
+        }
+
         data_d[rowx*p.KX + col] /= divisor;
     }
 }
@@ -2672,20 +2689,26 @@ argsort_src = """
 #extension GL_EXT_shader_16bit_storage : require

-layout(local_size_x = 1024, local_size_y = 1, local_size_z = 1) in;
+#define BLOCK_SIZE 1024
+#define ASC 0
+
+layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;

 layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
 layout (binding = 1) buffer D {int data_d[];};

 layout (push_constant) uniform parameter {
     uint ncols;
-    bool ascending;
+    uint ncols_pad;
+    uint order;
 } p;

+shared int dst_row[BLOCK_SIZE];
+
 void swap(uint idx0, uint idx1) {
-    int tmp = data_d[idx0];
-    data_d[idx0] = data_d[idx1];
-    data_d[idx1] = tmp;
+    int tmp = dst_row[idx0];
+    dst_row[idx0] = dst_row[idx1];
+    dst_row[idx1] = tmp;
 }

 void main() {
@@ -2693,36 +2716,45 @@ void main() {
     const int col = int(gl_LocalInvocationID.x);
     const uint row = gl_WorkGroupID.y;

-    if (col >= p.ncols) {
+    if (col >= p.ncols_pad) {
         return;
     }

-    const uint a_idx = row * p.ncols;
-    const uint d_idx = row * p.ncols;
+    const uint row_offset = row * p.ncols;

     // initialize indices
-    if (col < p.ncols) {
-        data_d[col] = col;
-    }
+    dst_row[col] = col;
     barrier();

-    for (uint k = 2; k <= p.ncols; k *= 2) {
+    for (uint k = 2; k <= p.ncols_pad; k *= 2) {
         for (uint j = k / 2; j > 0; j /= 2) {
             const uint ixj = col ^ j;
             if (ixj > col) {
                 if ((col & k) == 0) {
-                    if (p.ascending ? data_a[a_idx + data_d[d_idx + col]] > data_a[a_idx + data_d[d_idx + ixj]] : data_a[a_idx + data_d[d_idx + col]] < data_a[a_idx + data_d[d_idx + ixj]]) {
-                        swap(d_idx + col, d_idx + ixj);
+                    if (dst_row[col] >= p.ncols ||
+                        (dst_row[ixj] < p.ncols && (p.order == ASC ?
+                            data_a[row_offset + dst_row[col]] > data_a[row_offset + dst_row[ixj]] :
+                            data_a[row_offset + dst_row[col]] < data_a[row_offset + dst_row[ixj]]))
+                    ) {
+                        swap(col, ixj);
                     }
                 } else {
-                    if (p.ascending ? data_a[a_idx + data_d[d_idx + col]] < data_a[a_idx + data_d[d_idx + ixj]] : data_a[a_idx + data_d[d_idx + col]] > data_a[a_idx + data_d[d_idx + ixj]]) {
-                        swap(d_idx + col, d_idx + ixj);
+                    if (dst_row[ixj] >= p.ncols ||
+                        (dst_row[col] < p.ncols && (p.order == ASC ?
+                            data_a[row_offset + dst_row[col]] < data_a[row_offset + dst_row[ixj]] :
+                            data_a[row_offset + dst_row[col]] > data_a[row_offset + dst_row[ixj]]))
+                    ) {
+                        swap(col, ixj);
                     }
                 }
             }
             barrier();
         }
     }
+
+    if (col < p.ncols) {
+        data_d[row_offset + col] = dst_row[col];
+    }
 }
 """

klite.embd

@@ -7,7 +7,7 @@ Just copy this single static HTML file anywhere and open it in a browser, or fro
 Please go to https://github.com/LostRuins/lite.koboldai.net for updates on Kobold Lite.
 If you are submitting a pull request for Lite, PLEASE use the above repo, not the KoboldCpp one.
 Kobold Lite is under the AGPL v3.0 License unless otherwise exempted. Please do not remove this line.
-Current version: 138
+Current version: 139
 -Concedo
 -->
@@ -66,6 +66,33 @@ Current version: 138
     filter: invert(1);
 }

+.maincontainer {
+    padding-right: 4px;
+    padding-left: 4px;
+    margin-right: auto;
+    margin-left: auto;
+}
+
+@media (min-width: 768px) {
+    .adaptivecontainer {
+    width: 750px;
+}}
+
+@media (min-width: 992px) {
+    .adaptivecontainer {
+    width: 970px;
+}}
+
+@media (min-width: 1200px) {
+    .adaptivecontainer {
+    width: 1170px;
+}}
+
+@media (min-width: 1200px) {
+    .clampedcontainer {
+    width: 1170px;
+}}
+
 .settinglabel input {
     width: 6ch;
     background-color: #1a3364;
@@ -1234,13 +1261,6 @@ Current version: 138
     padding: 2px 6px;
 }

-.maincontainer {
-    padding-right: 4px;
-    padding-left: 4px;
-    margin-right: auto;
-    margin-left: auto;
-}
-
 .workerTableDiv,.shareStory{
     max-height: 420px;
     overflow-y: auto;
@@ -1951,16 +1971,16 @@ Current version: 138
 {
     height: calc(98vh - 240px);
 }
-@media (max-width: 598px) {
+@media (max-width: 534px) {
     .normal_viewport_height
     {
-        height: calc(98vh - 270px);
+        height: calc(98vh - 260px);
     }
 }
 @media (max-width: 342px) {
     .normal_viewport_height
     {
-        height: calc(98vh - 300px);
+        height: calc(98vh - 280px);
     }
 }
 @media print {
@@ -3705,6 +3725,7 @@ Current version: 138
     autoscroll: true, //automatically scroll to bottom on render
     printer_view: false, //automatically scroll to bottom on render
+    viewport_width_mode: 0, //0=adapt, 1=clamp, 2=unlock
     trimsentences: true, //trim to last punctuation
     trimwhitespace: false, //trim trailing whitespace
     compressnewlines: false, //compress multiple newlines
@@ -8411,6 +8432,7 @@ Current version: 138
     document.getElementById("top_p").value = document.getElementById("top_p_slide").value = localsettings.top_p;
     document.getElementById("autoscroll").checked = localsettings.autoscroll;
     document.getElementById("printer_view").checked = localsettings.printer_view;
+    document.getElementById("viewport_width_mode").value = localsettings.viewport_width_mode;
     document.getElementById("export_settings").checked = localsettings.export_settings;
     document.getElementById("show_advanced_load").checked = localsettings.show_advanced_load;
     document.getElementById("invert_colors").checked = localsettings.invert_colors;
@@ -8715,6 +8737,7 @@ Current version: 138
     localsettings.top_p = document.getElementById("top_p").value;
     localsettings.autoscroll = (document.getElementById("autoscroll").checked ? true : false);
     localsettings.printer_view = (document.getElementById("printer_view").checked ? true : false);
+    localsettings.viewport_width_mode = document.getElementById("viewport_width_mode").value;
     localsettings.export_settings = (document.getElementById("export_settings").checked ? true : false);
     localsettings.show_advanced_load = (document.getElementById("show_advanced_load").checked ? true : false);
     localsettings.invert_colors = (document.getElementById("invert_colors").checked ? true : false);
@@ -13223,6 +13246,23 @@ Current version: 138
         document.getElementById("chat_msg_body").classList.add("aesthetic_viewport_height");
     }

+    let maincon = document.getElementById("maincontainer");
+    if(localsettings.viewport_width_mode==1) //clamped
+    {
+        maincon.classList.remove("adaptivecontainer");
+        maincon.classList.add("clampedcontainer");
+    }
+    else if(localsettings.viewport_width_mode==2) //unlock
+    {
+        maincon.classList.remove("adaptivecontainer");
+        maincon.classList.remove("clampedcontainer");
+    }
+    else //adaptive
+    {
+        maincon.classList.add("adaptivecontainer");
+        maincon.classList.remove("clampedcontainer");
+    }
+
     update_genimg_button_visiblility();
     idle_timer = 0;
@@ -14590,7 +14630,7 @@ Current version: 138
 <body id="outerbody" class="">
-    <div class="container maincontainer">
+    <div id="maincontainer" class="adaptivecontainer maincontainer">
     <div id="outerbodybg"></div>
     <div class="" id="topmenu">
         <div id="menuitems">
@@ -15683,6 +15723,15 @@ Current version: 138
                 class="helptext">Unlocks the text viewport, allowing for infinite height without scrolling</span></span></div>
             <input type="checkbox" id="printer_view" style="margin:0px 0px 0px auto;">
         </div>
+        <div class="settinglabel">
+            <div class="justifyleft settingsmall" title="">Viewport Width <span class="helpicon">?<span
+                class="helptext">Controls horizontal scaling of the viewport window</span></span></div>
+            <select style="padding:1px; height:auto; width: 34px; appearance: none; font-size: 7pt; margin:0px 0px 0px auto;" class="form-control" id="viewport_width_mode">
+                <option value="0">Adapt</option>
+                <option value="1">Clamp</option>
+                <option value="2">Unlock</option>
+            </select>
+        </div>
         <div class="settinglabel">
             <div class="justifyleft settingsmall">Inverted Colors <span class="helpicon">?<span
                 class="helptext">Inverts all colors, simple light mode</span></span></div>

llama.cpp

@@ -1723,6 +1723,8 @@ struct llama_state {
     llama_state() {
 #ifdef GGML_USE_METAL
         ggml_backend_metal_log_set_callback(log_callback, log_callback_user_data);
+#elif defined(GGML_USE_CUDA)
+        ggml_backend_cuda_log_set_callback(log_callback, log_callback_user_data);
 #endif
     }
@@ -5258,7 +5260,14 @@ static bool llm_load_tensors(
         {
             model.output_norm   = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
             model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd});
-            model.output        = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab});
+            model.output        = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, false);
+            if (!model.output) {
+                // needs to be on GPU
+                model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+                ml.n_created--; // artificial tensor
+                ml.size_data += ggml_nbytes(model.output);
+            }
         }

         for (int i = 0; i < n_layer; ++i) {
@@ -12867,16 +12876,16 @@ struct llm_tokenizer_wpm {
         // to lowercase, pad chinese characters, pad punctuation
         std::string new_str = "";
         for (uint32_t code : cpts_nfd) {
-            int type = unicode_cpt_type(code);
+            const codepoint_flags flags = unicode_cpt_flags(code);
-            if (type == CODEPOINT_TYPE_ACCENT_MARK || type == CODEPOINT_TYPE_CONTROL) {
+            if (flags.is_accent_mark || flags.is_control) {
                 continue;
             }
             code = unicode_tolower(code);
-            if (type == CODEPOINT_TYPE_SEPARATOR) {
+            if (flags.is_separator || flags.is_whitespace) {  //####FIXME: is_separator ?
                 code = ' ';
             }
             std::string s = unicode_cpt_to_utf8(code);
-            if (type == CODEPOINT_TYPE_PUNCTUATION || is_ascii_punct(code) || is_chinese_char(code)) {
+            if (flags.is_punctuation || is_ascii_punct(code) || is_chinese_char(code)) {
                 new_str += " ";
                 new_str += s;
                 new_str += " ";
@@ -18485,6 +18494,8 @@ void llama_log_set(ggml_log_callback log_callback, void * user_data) {
     g_state.log_callback_user_data = user_data;
 #ifdef GGML_USE_METAL
     ggml_backend_metal_log_set_callback(g_state.log_callback, g_state.log_callback_user_data);
+#elif defined(GGML_USE_CUDA)
+    ggml_backend_cuda_log_set_callback(g_state.log_callback, g_state.log_callback_user_data);
 #endif
 }
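The llm_load_tensors hunk above makes the output projection optional: create_tensor is called with required=false, and when the GGUF carries no dedicated output matrix the token embedding matrix is reused (tied embeddings). The same pattern, restated with the bookkeeping commented (no new API, just the diff annotated):

// Try the dedicated output matrix first; fall back to tied embeddings.
model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, false);
if (!model.output) {
    // reuse token_embd.weight as the output head; it needs to live on the GPU
    model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
    ml.n_created--;                            // an alias, not a newly created tensor
    ml.size_data += ggml_nbytes(model.output); // but its data must still be loaded
}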

requirements.txt

@@ -1,5 +1,5 @@
-numpy==1.24.4
-sentencepiece==0.1.98
+numpy>=1.24.4
+sentencepiece>=0.1.98
 transformers>=4.34.0
 gguf>=0.1.0
 customtkinter>=5.1.0

scripts/gen-unicode-data.py

@@ -1,64 +1,134 @@
 import regex
+import ctypes
+import unicodedata

-def get_matches(regex_expr):
-    regex_expr_compiled = regex.compile(regex_expr)
-    unicode_ranges = []
-    current_range = None
-
-    for codepoint in range(0x110000):
-        char = chr(codepoint)
-        if regex_expr_compiled.match(char):
-            if current_range is None:
-                current_range = [codepoint, codepoint]
-            else:
-                current_range[1] = codepoint
-        elif current_range is not None:
-            unicode_ranges.append(tuple(current_range))
-            current_range = None
-
-    if current_range is not None:
-        unicode_ranges.append(tuple(current_range))
-
-    return unicode_ranges
-
-
-def print_cat(mode, cat, ranges):
-    if mode == "range":
-        print("const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_{} = {{".format(cat))  # noqa: NP100
-    if mode == "map":
-        print("const std::map<uint32_t, uint32_t> unicode_map_{} = {{".format(cat))  # noqa: NP100
-    for i, values in enumerate(ranges):
-        end = ",\n" if (i % 4 == 3 or i + 1 == len(ranges)) else ", "
-        values = ["0x%08X" % value for value in values]
-        print("{" + ", ".join(values) + "}", end=end)  # noqa: NP100
-    print("};")  # noqa: NP100
-    print("")  # noqa: NP100
-
-
-print_cat("range", "number",      get_matches(r'\p{N}'))
-print_cat("range", "letter",      get_matches(r'\p{L}'))
-print_cat("range", "separator",   get_matches(r'\p{Z}'))
-print_cat("range", "accent_mark", get_matches(r'\p{M}'))
-print_cat("range", "punctuation", get_matches(r'\p{P}'))
-print_cat("range", "symbol",      get_matches(r'\p{S}'))
-print_cat("range", "control",     get_matches(r'\p{C}'))
-print_cat("range", "whitespace",  get_matches(r'\s'))
-
-
-map_lowercase = []
-map_uppercase = []
-for codepoint in range(0x110000):
-    char = chr(codepoint)
-    lower = ord(char.lower()[0])
-    upper = ord(char.upper()[0])
-    if codepoint != lower:
-        map_lowercase.append((codepoint, lower))
-    if codepoint != upper:
-        map_uppercase.append((codepoint, upper))
-print_cat("map", "lowercase", map_lowercase)
-print_cat("map", "uppercase", map_uppercase)
-
-
-# TODO: generate unicode_map_nfd
+
+class CoodepointFlags (ctypes.Structure):
+    _fields_ = [  # see definition in unicode.h
+        ("is_undefined",   ctypes.c_uint16, 1),
+        ("is_number",      ctypes.c_uint16, 1),  # regex: \p{N}
+        ("is_letter",      ctypes.c_uint16, 1),  # regex: \p{L}
+        ("is_separator",   ctypes.c_uint16, 1),  # regex: \p{Z}
+        ("is_accent_mark", ctypes.c_uint16, 1),  # regex: \p{M}
+        ("is_punctuation", ctypes.c_uint16, 1),  # regex: \p{P}
+        ("is_symbol",      ctypes.c_uint16, 1),  # regex: \p{S}
+        ("is_control",     ctypes.c_uint16, 1),  # regex: \p{C}
+    ]
+
+
+assert (ctypes.sizeof(CoodepointFlags) == 2)
+
+
+MAX_CODEPOINTS = 0x110000
+
+regex_number      = regex.compile(r'\p{N}')
+regex_letter      = regex.compile(r'\p{L}')
+regex_separator   = regex.compile(r'\p{Z}')
+regex_accent_mark = regex.compile(r'\p{M}')
+regex_punctuation = regex.compile(r'\p{P}')
+regex_symbol      = regex.compile(r'\p{S}')
+regex_control     = regex.compile(r'\p{C}')
+regex_whitespace  = regex.compile(r'\s')
+
+codepoint_flags = (CoodepointFlags * MAX_CODEPOINTS)()
+table_whitespace = []
+table_lowercase = []
+table_uppercase = []
+table_nfd = []
+
+for codepoint in range(MAX_CODEPOINTS):
+    # convert codepoint to unicode character
     char = chr(codepoint)
+
+    # regex categories
+    flags = codepoint_flags[codepoint]
+    flags.is_number      = bool(regex_number.match(char))
+    flags.is_letter      = bool(regex_letter.match(char))
+    flags.is_separator   = bool(regex_separator.match(char))
+    flags.is_accent_mark = bool(regex_accent_mark.match(char))
+    flags.is_punctuation = bool(regex_punctuation.match(char))
+    flags.is_symbol      = bool(regex_symbol.match(char))
+    flags.is_control     = bool(regex_control.match(char))
+    flags.is_undefined   = bytes(flags)[0] == 0
+    assert (not flags.is_undefined)
+
+    # whitespaces
+    if bool(regex_whitespace.match(char)):
+        table_whitespace.append(codepoint)
+
+    # lowercase conversion
     lower = ord(char.lower()[0])
     if codepoint != lower:
+        table_lowercase.append((codepoint, lower))
+
+    # uppercase conversion
+    upper = ord(char.upper()[0])
     if codepoint != upper:
+        table_uppercase.append((codepoint, upper))
+
+    # NFD normalization
+    norm = ord(unicodedata.normalize('NFD', char)[0])
+    if codepoint != norm:
+        table_nfd.append((codepoint, norm))
+
+
+# group ranges with same flags
+ranges_flags = [(0, codepoint_flags[0])]  # start, flags
+for codepoint, flags in enumerate(codepoint_flags):
+    if bytes(flags) != bytes(ranges_flags[-1][1]):
+        ranges_flags.append((codepoint, flags))
+ranges_flags.append((MAX_CODEPOINTS, CoodepointFlags()))
+
+
+# group ranges with same nfd
+ranges_nfd = [(0, 0, 0)]  # start, last, nfd
+for codepoint, norm in table_nfd:
+    start = ranges_nfd[-1][0]
+    if ranges_nfd[-1] != (start, codepoint - 1, norm):
+        ranges_nfd.append(None)
+        start = codepoint
+    ranges_nfd[-1] = (start, codepoint, norm)
+
+
+# Generate 'unicode-data.cpp'
+
+
+def out(line=""):
+    print(line, end='\n')  # noqa
+
+
+out("""\
+// generated with scripts/gen-unicode-data.py
+
+#include "unicode-data.h"
+
+#include <cstdint>
+#include <vector>
+#include <unordered_map>
+#include <unordered_set>
+""")
+
+out("const std::vector<std::pair<uint32_t, uint16_t>> unicode_ranges_flags = {  // start, flags // last=next_start-1")
+for codepoint, flags in ranges_flags:
+    flags = int.from_bytes(bytes(flags), "little")
+    out("{0x%06X, 0x%04X}," % (codepoint, flags))
+out("};\n")
+
+out("const std::unordered_set<uint32_t> unicode_set_whitespace = {")
+out(", ".join("0x%06X" % cpt for cpt in table_whitespace))
+out("};\n")
+
+out("const std::unordered_map<uint32_t, uint32_t> unicode_map_lowercase = {")
+for tuple in table_lowercase:
+    out("{0x%06X, 0x%06X}," % tuple)
+out("};\n")
+
+out("const std::unordered_map<uint32_t, uint32_t> unicode_map_uppercase = {")
+for tuple in table_uppercase:
+    out("{0x%06X, 0x%06X}," % tuple)
+out("};\n")
+
+out("const std::vector<range_nfd> unicode_ranges_nfd = {  // start, last, nfd")
+for triple in ranges_nfd:
+    out("{0x%06X, 0x%06X, 0x%06X}," % triple)
+out("};\n")

tests/test-tokenizer-random.py

@@ -1,5 +1,5 @@
 # Test libllama tokenizer == AutoTokenizer.
-# Brute force random tokens/text generation.
+# Brute force random words/text generation.
 #
 # Sample usage:
 #
@@ -12,10 +12,10 @@ import argparse
 import subprocess
 import random

-from typing import Iterator
+from typing import Callable, Iterator

 import cffi
-from transformers import AutoTokenizer, PreTrainedTokenizerBase
+from transformers import AutoTokenizer

 logger = logging.getLogger("test-tokenizer-random-bpe")

@@ -152,21 +152,28 @@ def generator_custom_text_edge_cases() -> Iterator[str]:
         'a 〇b',  # unicode_ranges_digit, 0x3007
         'Ⅵ-a',   # unicode_ranges_digit, {0x00002150, 0x0000218F} // Number Forms
         '\uFEFF//',  # unicode_ranges_control, 0xFEFF (BOM)
-        '<s>a'  # TODO: Phi-3 fail
+        'Cửa Việt',  # llama-3, ignore_merges = true
+        '<s>a',  # TODO: Phi-3 fail
+        'a\na',  # TODO: Bert fail
     ]


+def generator_vocab_words(vocab: list[str]) -> Iterator[str]:
+    """Brute force check all vocab words"""
+    yield from vocab
+
+
 def generator_random_chars(iterations=100) -> Iterator[str]:
     """Brute force random text with simple characters"""

     WHITESPACES = list(" " * 20 + "\n" * 5 + "\r\n" * 5 + "\t" * 5)
-    CHARS = list(set("""
+    CHARS = list(sorted(set("""
 ABCDEFGHIJKLMNOPQRSTUVWXYZ
 abcdefghijklmnopqrstuvwxyz
 ÁÉÍÓÚÀÈÌÒÙÂÊÎÔÛÄËÏÖÜ
 áéíóúàèìòùâêîôûäëïöü
 .-,*/-+ª!"·$%&/()=?¿[]{}<>\\|@#~½¬~;:_
-    """))
+    """)))

     rand = random.Random()
     for m in range(iterations):
@@ -181,13 +188,13 @@ def generator_random_chars(iterations = 100) -> Iterator[str]:
         yield "".join(text)


-def generator_random_vocab_chars(tokenizer: PreTrainedTokenizerBase, iterations = 100) -> Iterator[str]:
+def generator_random_vocab_chars(vocab: list[str], iterations=100) -> Iterator[str]:
     """Brute force random text with vocab characters"""

-    vocab_ids = list(tokenizer.vocab.values())
-    vocab_text = tokenizer.decode(vocab_ids, skip_special_tokens=True)
-    vocab_chars = list(set(vocab_text))
-    del vocab_ids, vocab_text
+    vocab_chars = set()
+    for word in vocab:
+        vocab_chars.update(word)
+    vocab_chars = list(sorted(vocab_chars))

     rand = random.Random()
     for m in range(iterations):
@@ -196,19 +203,11 @@ def generator_random_vocab_chars(tokenizer: PreTrainedTokenizerBase, iterations
         yield "".join(text)


-def generator_random_vocab_tokens(tokenizer: PreTrainedTokenizerBase, iterations = 100) -> Iterator[str]:
-    """Brute force random text from vocab tokens"""
-
-    space_id = tokenizer.encode(" ", add_special_tokens=False)[0]
-    vocab_ids = list(tokenizer.vocab.values())
-    vocab_ids = list(sorted(vocab_ids + vocab_ids))
-    for i in range(1, len(vocab_ids), 2):
-        vocab_ids[i] = space_id
-    vocab_tokens = tokenizer.decode(vocab_ids, skip_special_tokens=True)
-    vocab_tokens = vocab_tokens.split(" ")
-    del vocab_ids
-
-    yield from vocab_tokens
+def generator_random_vocab_words(vocab: list[str], iterations=100) -> Iterator[str]:
+    """Brute force random text from vocab words"""
+
+    vocab = [w.strip() for w in vocab]
+    yield from vocab

     rand = random.Random()
     for m in range(iterations):
@@ -217,10 +216,9 @@ def generator_random_vocab_tokens(tokenizer: PreTrainedTokenizerBase, iterations
         num_words = rand.randint(300, 400)
         for i in range(num_words):
             k = rand.randint(1, 3)
-            tokens = rand.choices(vocab_tokens, k=k)
-            tokens = [t.strip(" \n\r\t") for t in tokens]
+            words = rand.choices(vocab, k=k)
             sep = rand.choice(" \n\r\t")
-            text.append("".join(tokens) + sep)
+            text.append("".join(words) + sep)
         yield "".join(text)

@@ -242,7 +240,7 @@ def generator_random_bytes(iterations = 100) -> Iterator[str]:
         yield "".join(text)


-def test_compare_tokenizer(model: LibLlamaModel, tokenizer: PreTrainedTokenizerBase, generator: Iterator[str]):
+def test_compare_tokenizer(func_tokenize1: Callable, func_tokenize2: Callable, generator: Iterator[str]):

     def find_first_mismatch(ids1: list[int], ids2: list[int]):
         for i, (a, b) in enumerate(zip(ids1, ids2)):
@@ -255,15 +253,12 @@ def test_compare_tokenizer(model: LibLlamaModel, tokenizer: PreTrainedTokenizerB
     t0 = time.perf_counter()
     logger.info("%s: %s" % (generator.__name__, "ini"))
     for text in generator:
-        ids1 = model.tokenize(text, add_special=False, parse_special=False)
-        ids2 = tokenizer.encode(text, add_special_tokens=False)
+        ids1 = func_tokenize1(text)
+        ids2 = func_tokenize2(text)
         if ids1 != ids2:
             i = find_first_mismatch(ids1, ids2)
             ids1 = list(ids1)[max(0, i - 2) : i + 2 + 1]
             ids2 = list(ids2)[max(0, i - 2) : i + 2 + 1]
-            text2 = tokenizer.decode(ids2, skip_special_tokens=True)
-            assert (text2 in text)
-            logger.info(" Text:     " + repr(text2))
             logger.info(" TokenIDs: " + str(ids1))
             logger.info(" Expected: " + str(ids2))
             raise Exception()
@@ -271,25 +266,37 @@ def test_compare_tokenizer(model: LibLlamaModel, tokenizer: PreTrainedTokenizerB
     logger.info("%s: end, time: %.3f secs" % (generator.__name__, t1 - t0))


-if __name__ == "__main__":
+def main(argv: list[str] = None):
     parser = argparse.ArgumentParser()
     parser.add_argument("vocab_file", help="path to vocab 'gguf' file")
     parser.add_argument("dir_tokenizer", help="directory containing 'tokenizer.model' file")
     parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
-    args = parser.parse_args()
+    args = parser.parse_args(argv)

     logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)

-    model = LibLlamaModel(LibLlama(), args.vocab_file, mparams=dict(vocab_only=True), cparams=dict(n_ctx=2048))
+    model = LibLlamaModel(LibLlama(), args.vocab_file, mparams=dict(vocab_only=True), cparams=dict(n_ctx=4096))
     tokenizer = AutoTokenizer.from_pretrained(args.dir_tokenizer)

-    test_compare_tokenizer(model, tokenizer, generator_custom_text())
-    test_compare_tokenizer(model, tokenizer, generator_custom_text_edge_cases())
-    test_compare_tokenizer(model, tokenizer, generator_random_chars(10_000))
-    test_compare_tokenizer(model, tokenizer, generator_random_vocab_chars(tokenizer, 10_000))
-    test_compare_tokenizer(model, tokenizer, generator_random_vocab_tokens(tokenizer, 10_000))
-    # test_compare_tokenizer(model, tokenizer, generator_random_bytes(10_000))  # FAIL
+    def func_tokenize2(text: str):
+        return tokenizer.encode(text, add_special_tokens=False)
+
+    parse_special = all(len(func_tokenize2(t)) == 1 for t in tokenizer.all_special_tokens)
+
+    def func_tokenize1(text: str):
+        return model.tokenize(text, add_special=False, parse_special=parse_special)
+
+    vocab = list(sorted(tokenizer.batch_decode(list(tokenizer.get_vocab().values()), skip_special_tokens=True)))
+    test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_custom_text())
+    test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_custom_text_edge_cases())
+    test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_vocab_words(vocab))
+    test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_chars(10_000))
+    test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_vocab_chars(vocab, 10_000))
+    test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_vocab_words(vocab, 10_000))
+    # test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_bytes(10_000))  # FAIL

     model.free()
+
+
+if __name__ == "__main__":
+    main()

unicode-data.cpp (file diff suppressed because it is too large)

unicode-data.h

@@ -1,17 +1,20 @@
 #pragma once

 #include <cstdint>
-#include <map>
-#include <utility>
 #include <vector>
+#include <unordered_map>
+#include <unordered_set>

-extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_number;
-extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_letter;
-extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_separator;
-extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_whitespace;
-extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_accent_mark;
-extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_punctuation;
-extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_symbol;
-extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_control;
-extern const std::multimap<uint32_t, uint32_t> unicode_map_nfd;
-extern const std::map<char32_t, char32_t> unicode_map_lowercase;
+struct range_nfd {
+    uint32_t first;
+    uint32_t last;
+    uint32_t nfd;
+};
+
+static const uint32_t MAX_CODEPOINTS = 0x110000;
+
+extern const std::vector<std::pair<uint32_t, uint16_t>> unicode_ranges_flags;
+extern const std::unordered_set<uint32_t> unicode_set_whitespace;
+extern const std::unordered_map<uint32_t, uint32_t> unicode_map_lowercase;
+extern const std::unordered_map<uint32_t, uint32_t> unicode_map_uppercase;
+extern const std::vector<range_nfd> unicode_ranges_nfd;
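range_nfd above packs runs of consecutive codepoints whose NFD decomposition starts with the same codepoint. Given that the ranges are sorted by first, as the generator emits them, a lookup is one binary search; a minimal sketch (to_nfd is our name, not part of this interface):

#include <algorithm>
#include <cstdint>
#include <vector>

// Map cpt to the first codepoint of its NFD decomposition,
// or return it unchanged when no range covers it.
static uint32_t to_nfd(const std::vector<range_nfd> & ranges, uint32_t cpt) {
    auto it = std::upper_bound(ranges.begin(), ranges.end(), cpt,
        [](uint32_t v, const range_nfd & r) { return v < r.first; });
    if (it != ranges.begin() && (--it)->last >= cpt) {
        return it->nfd;
    }
    return cpt;
}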

unicode.cpp

@ -1,4 +1,4 @@
#include "unicode.h" #include "unicode.h"
#include "unicode-data.h" #include "unicode-data.h"
#include <cassert> #include <cassert>
@ -109,57 +109,49 @@ static uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset)
// return result; // return result;
//} //}
static std::unordered_map<uint32_t, int> unicode_cpt_type_map() { static std::vector<codepoint_flags> unicode_cpt_flags_array() {
std::unordered_map<uint32_t, int> cpt_types; std::vector<codepoint_flags> cpt_flags(MAX_CODEPOINTS, codepoint_flags::UNDEFINED);
for (auto p : unicode_ranges_number) {
for (auto i = p.first; i <= p.second; ++i) { assert (unicode_ranges_flags.front().first == 0);
cpt_types[i] = CODEPOINT_TYPE_NUMBER; assert (unicode_ranges_flags.back().first == MAX_CODEPOINTS);
for (size_t i = 1; i < unicode_ranges_flags.size(); ++i) {
const auto range_ini = unicode_ranges_flags[i-1]; // codepoint_ini, flags
const auto range_end = unicode_ranges_flags[i]; // codepoint_end, flags
for (uint32_t cpt = range_ini.first; cpt < range_end.first; ++cpt) {
cpt_flags[cpt] = range_ini.second;
} }
} }
for (auto p : unicode_ranges_letter) {
for (auto i = p.first; i <= p.second; ++i) { for (auto cpt : unicode_set_whitespace) {
cpt_types[i] = CODEPOINT_TYPE_LETTER; cpt_flags[cpt].is_whitespace = true;
} }
for (auto p : unicode_map_lowercase) {
cpt_flags[p.second].is_lowercase = true;
} }
for (auto p : unicode_ranges_separator) {
for (auto i = p.first; i <= p.second; ++i) { for (auto p : unicode_map_uppercase) {
cpt_types[i] = CODEPOINT_TYPE_SEPARATOR; cpt_flags[p.second].is_uppercase = true;
} }
for (auto &range : unicode_ranges_nfd) { // start, last, nfd
cpt_flags[range.nfd].is_nfd = true;
} }
for (auto p : unicode_ranges_accent_mark) {
for (auto i = p.first; i <= p.second; ++i) { return cpt_flags;
cpt_types[i] = CODEPOINT_TYPE_ACCENT_MARK;
}
}
for (auto p : unicode_ranges_punctuation) {
for (auto i = p.first; i <= p.second; ++i) {
cpt_types[i] = CODEPOINT_TYPE_PUNCTUATION;
}
}
for (auto p : unicode_ranges_symbol) {
for (auto i = p.first; i <= p.second; ++i) {
cpt_types[i] = CODEPOINT_TYPE_SYMBOL;
}
}
for (auto p : unicode_ranges_control) {
for (auto i = p.first; i <= p.second; ++i) {
cpt_types[i] = CODEPOINT_TYPE_CONTROL;
}
}
return cpt_types;
} }
static std::unordered_map<uint8_t, std::string> unicode_byte_to_utf8_map() { static std::unordered_map<uint8_t, std::string> unicode_byte_to_utf8_map() {
std::unordered_map<uint8_t, std::string> map; std::unordered_map<uint8_t, std::string> map;
for (int ch = u'!'; ch <= u'~'; ++ch) { for (int ch = 0x21; ch <= 0x7E; ++ch) { // u'!' to u'~'
assert(0 <= ch && ch < 256); assert(0 <= ch && ch < 256);
map[ch] = unicode_cpt_to_utf8(ch); map[ch] = unicode_cpt_to_utf8(ch);
} }
for (int ch = u'¡'; ch <= u'¬'; ++ch) { for (int ch = 0xA1; ch <= 0xAC; ++ch) { // u'¡' to u'¬'
assert(0 <= ch && ch < 256); assert(0 <= ch && ch < 256);
map[ch] = unicode_cpt_to_utf8(ch); map[ch] = unicode_cpt_to_utf8(ch);
} }
for (int ch = u'®'; ch <= u'ÿ'; ++ch) { for (int ch = 0xAE; ch <= 0xFF; ++ch) { // u'®' to u'ÿ'
assert(0 <= ch && ch < 256); assert(0 <= ch && ch < 256);
map[ch] = unicode_cpt_to_utf8(ch); map[ch] = unicode_cpt_to_utf8(ch);
} }
@ -175,15 +167,15 @@ static std::unordered_map<uint8_t, std::string> unicode_byte_to_utf8_map() {
static std::unordered_map<std::string, uint8_t> unicode_utf8_to_byte_map() { static std::unordered_map<std::string, uint8_t> unicode_utf8_to_byte_map() {
std::unordered_map<std::string, uint8_t> map; std::unordered_map<std::string, uint8_t> map;
for (int ch = u'!'; ch <= u'~'; ++ch) { for (int ch = 0x21; ch <= 0x7E; ++ch) { // u'!' to u'~'
assert(0 <= ch && ch < 256); assert(0 <= ch && ch < 256);
map[unicode_cpt_to_utf8(ch)] = ch; map[unicode_cpt_to_utf8(ch)] = ch;
} }
for (int ch = u'¡'; ch <= u'¬'; ++ch) { for (int ch = 0xA1; ch <= 0xAC; ++ch) { // u'¡' to u'¬'
assert(0 <= ch && ch < 256); assert(0 <= ch && ch < 256);
map[unicode_cpt_to_utf8(ch)] = ch; map[unicode_cpt_to_utf8(ch)] = ch;
} }
for (int ch = u'®'; ch <= u'ÿ'; ++ch) { for (int ch = 0xAE; ch <= 0xFF; ++ch) { // u'®' to u'ÿ'
assert(0 <= ch && ch < 256); assert(0 <= ch && ch < 256);
map[unicode_cpt_to_utf8(ch)] = ch; map[unicode_cpt_to_utf8(ch)] = ch;
} }
@ -238,8 +230,9 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & t
return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : 0; return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : 0;
}; };
auto _get_cpt_type = [&] (const size_t pos) -> int { auto _get_flags = [&] (const size_t pos) -> codepoint_flags {
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_type(cpts[pos]) : CODEPOINT_TYPE_UNIDENTIFIED; static const codepoint_flags undef(codepoint_flags::UNDEFINED);
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : undef;
}; };
size_t _prev_end = offset_ini; size_t _prev_end = offset_ini;
@ -261,7 +254,7 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & t
for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) { for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) {
const char32_t cpt = _get_cpt(pos); const char32_t cpt = _get_cpt(pos);
const int cpt_type = _get_cpt_type(pos); const auto flags = _get_flags(pos);
// regex: 's|'t|'re|'ve|'m|'ll|'d // regex: 's|'t|'re|'ve|'m|'ll|'d
if (cpt == '\'' && pos+1 < offset_end) { if (cpt == '\'' && pos+1 < offset_end) {
@ -281,39 +274,37 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & t
} }
} }
char32_t cpt2 = (cpt == ' ' ? _get_cpt(pos+1) : cpt); auto flags2 = (cpt == ' ' ? _get_flags(pos+1) : flags);
int cpt2_type = (cpt == ' ' ? _get_cpt_type(pos+1) : cpt_type);
// regex: <space>?\p{L}+ // regex: <space>?\p{L}+
if (cpt2_type == CODEPOINT_TYPE_LETTER) { if (flags2.is_letter) {
pos += (cpt == ' '); pos += (cpt == ' ');
while (cpt2_type == CODEPOINT_TYPE_LETTER) { while (flags2.is_letter) {
cpt2_type = _get_cpt_type(++pos); flags2 = _get_flags(++pos);
} }
_add_token(pos); _add_token(pos);
continue; continue;
} }
// regex: <space>?\p{N}+ // regex: <space>?\p{N}+
if (cpt2_type == CODEPOINT_TYPE_NUMBER) { if (flags2.is_number) {
pos += (cpt == ' '); pos += (cpt == ' ');
while (cpt2_type == CODEPOINT_TYPE_NUMBER) { while (flags2.is_number) {
cpt2_type = _get_cpt_type(++pos); flags2 = _get_flags(++pos);
} }
_add_token(pos); _add_token(pos);
continue; continue;
} }
// regex: <space>?[^\s\p{L}\p{N}]+ // regex: <space>?[^\s\p{L}\p{N}]+
if (!unicode_cpt_is_whitespace(cpt2) && cpt2_type != CODEPOINT_TYPE_LETTER && cpt2_type != CODEPOINT_TYPE_NUMBER && cpt2_type != CODEPOINT_TYPE_UNIDENTIFIED) { if (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number || flags2.is_undefined)) {
pos += (cpt == ' '); pos += (cpt == ' ');
while (!unicode_cpt_is_whitespace(cpt2) && cpt2_type != CODEPOINT_TYPE_LETTER && cpt2_type != CODEPOINT_TYPE_NUMBER && cpt2_type != CODEPOINT_TYPE_UNIDENTIFIED) { while (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number || flags2.is_undefined)) {
cpt2_type = _get_cpt_type(++pos); flags2 = _get_flags(++pos);
cpt2 = _get_cpt(pos);
} }
_add_token(pos); _add_token(pos);
continue; continue;
} }
size_t num_whitespaces = 0; size_t num_whitespaces = 0;
while (unicode_cpt_is_whitespace(_get_cpt(pos+num_whitespaces))) { while (_get_flags(pos+num_whitespaces).is_whitespace) {
num_whitespaces++; num_whitespaces++;
} }
@@ -357,8 +348,9 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
         return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : 0;
     };
-    auto _get_cpt_type = [&] (const size_t pos) -> int {
-        return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_type(cpts[pos]) : CODEPOINT_TYPE_UNIDENTIFIED;
+    auto _get_flags = [&] (const size_t pos) -> codepoint_flags {
+        static const codepoint_flags undef(codepoint_flags::UNDEFINED);
+        return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : undef;
     };
     size_t _prev_end = offset_ini;
@@ -380,7 +372,7 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
     for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) {
         const char32_t cpt = _get_cpt(pos);
-        const int cpt_type = _get_cpt_type(pos);
+        const auto flags = _get_flags(pos);
         // regex: (?i:'s|'t|'re|'ve|'m|'ll|'d) // case insensitive
         if (cpt == '\'' && pos+1 < offset_end) {
@@ -401,10 +393,10 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
         }
         // regex: [^\r\n\p{L}\p{N}]?\p{L}+ //####FIXME: the first \p{L} is correct?
-        if (cpt != '\r' && cpt != '\n' && /*cpt_type != CODEPOINT_TYPE_LETTER &&*/ cpt_type != CODEPOINT_TYPE_NUMBER) {
-            if (cpt_type == CODEPOINT_TYPE_LETTER || _get_cpt_type(pos+1) == CODEPOINT_TYPE_LETTER) { // one or more letters
+        if (!(cpt == '\r' || cpt == '\n' || /*flags.is_letter |*/ flags.is_number)) {
+            if (flags.is_letter || _get_flags(pos+1).is_letter) { // one or more letters
                 pos++;
-                while (_get_cpt_type(pos) == CODEPOINT_TYPE_LETTER) {
+                while (_get_flags(pos).is_letter) {
                     pos++;
                 }
                 _add_token(pos);
@@ -413,9 +405,9 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
         }
         // regex: \p{N}{1,3}
-        if (cpt_type == CODEPOINT_TYPE_NUMBER) {
+        if (flags.is_number) {
             size_t ini = pos;
-            while (_get_cpt_type(pos) == CODEPOINT_TYPE_NUMBER) {
+            while (_get_flags(pos).is_number) {
                 if (++pos - ini >= 3 ) {
                     _add_token(pos);
                     ini = pos;
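The \p{N}{1,3} rule emits a token boundary every time three digits accumulate, with any remainder flushed after the loop (outside this hunk). The net effect, re-implemented standalone for ASCII digits (function name hypothetical):

    #include <string>
    #include <vector>

    // "1234567" -> {"123", "456", "7"}, mirroring _add_token firing at pos - ini == 3.
    static std::vector<std::string> group_digits(const std::string & digits) {
        std::vector<std::string> out;
        for (size_t i = 0; i < digits.size(); i += 3) {
            out.push_back(digits.substr(i, 3));  // substr clamps the final short group
        }
        return out;
    }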
@@ -426,14 +418,13 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
         }
         // regex: <space>?[^\s\p{L}\p{N}]+[\r\n]*
-        char32_t cpt2 = (cpt == ' ' ? _get_cpt(pos+1) : cpt);
-        int cpt2_type = (cpt == ' ' ? _get_cpt_type(pos+1) : cpt_type);
-        if (!unicode_cpt_is_whitespace(cpt2) && cpt2_type != CODEPOINT_TYPE_LETTER && cpt2_type != CODEPOINT_TYPE_NUMBER && cpt2_type != CODEPOINT_TYPE_UNIDENTIFIED) {
+        auto flags2 = (cpt == ' ' ? _get_flags(pos+1) : flags);
+        if (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number || flags2.is_undefined)) {
             pos += (cpt == ' ');
-            while (!unicode_cpt_is_whitespace(cpt2) && cpt2_type != CODEPOINT_TYPE_LETTER && cpt2_type != CODEPOINT_TYPE_NUMBER && cpt2_type != CODEPOINT_TYPE_UNIDENTIFIED) {
-                cpt2_type = _get_cpt_type(++pos);
-                cpt2 = _get_cpt(pos);
+            while (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number || flags2.is_undefined)) {
+                flags2 = _get_flags(++pos);
             }
+            char32_t cpt2 = _get_cpt(pos);
             while (cpt2 == '\r' || cpt2 == '\n') {
                 cpt2 = _get_cpt(++pos);
             }
@@ -443,7 +434,7 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
         size_t num_whitespaces = 0;
         size_t last_end_r_or_n = 0;
-        while (unicode_cpt_is_whitespace(_get_cpt(pos+num_whitespaces))) {
+        while (_get_flags(pos+num_whitespaces).is_whitespace) {
             char32_t cpt2 = _get_cpt(pos+num_whitespaces);
             if (cpt2 == '\r' || cpt2 == '\n') {
                 last_end_r_or_n = pos + num_whitespaces + 1;
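This lookahead measures one whitespace run and records the offset just past the last \r or \n inside it, which is what the \s*[\r\n]+ rule needs to decide where to cut. A standalone sketch of the two counters (count_ws is hypothetical; unicode_cpt_flags is the lookup introduced further down):

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    static void count_ws(const std::vector<uint32_t> & cpts, const size_t pos,
                         size_t & num_whitespaces, size_t & last_end_r_or_n) {
        num_whitespaces = 0;
        last_end_r_or_n = 0;
        while (pos + num_whitespaces < cpts.size() &&
               unicode_cpt_flags(cpts[pos + num_whitespaces]).is_whitespace) {
            const uint32_t cpt2 = cpts[pos + num_whitespaces];
            if (cpt2 == '\r' || cpt2 == '\n') {
                last_end_r_or_n = pos + num_whitespaces + 1;
            }
            num_whitespaces++;
        }
        // e.g. cpts = { ' ', '\n', '\n', 'b' }, pos = 0
        //      -> num_whitespaces = 3, last_end_r_or_n = 3
    }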
@@ -589,15 +580,14 @@ std::string unicode_cpt_to_utf8(uint32_t cp) {
 }
 
 std::vector<uint32_t> unicode_cpts_normalize_nfd(const std::vector<uint32_t> & cpts) {
-    std::vector<uint32_t> result;
-    result.reserve(cpts.size());
+    auto comp = [] (const uint32_t cpt, const range_nfd & range) {
+        return cpt < range.first;
+    };
+    std::vector<uint32_t> result(cpts.size());
     for (size_t i = 0; i < cpts.size(); ++i) {
-        auto it = unicode_map_nfd.find(cpts[i]);
-        if (it == unicode_map_nfd.end()) {
-            result.push_back(cpts[i]);
-        } else {
-            result.push_back(it->second);
-        }
+        const uint32_t cpt = cpts[i];
+        auto it = std::upper_bound(unicode_ranges_nfd.cbegin(), unicode_ranges_nfd.cend(), cpt, comp) - 1;
+        result[i] = (it->first <= cpt && cpt <= it->last) ? it->nfd : cpt;
     }
     return result;
 }
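unicode_map_nfd, a per-codepoint hash map, is replaced by a binary search over contiguous ranges, which shrinks the table and avoids hashing. The range_nfd layout is not shown in this diff; the sketch below infers it from the it->first / it->last / it->nfd accesses and assumes the generated table begins with a low sentinel range so that the - 1 never steps before cbegin():

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    struct range_nfd {          // layout inferred, not quoted from the tree
        uint32_t first, last, nfd;
    };

    static uint32_t nfd_of(const std::vector<range_nfd> & ranges, const uint32_t cpt) {
        auto comp = [] (const uint32_t c, const range_nfd & r) { return c < r.first; };
        // first range starting beyond cpt, then step back to the candidate range
        auto it = std::upper_bound(ranges.cbegin(), ranges.cend(), cpt, comp) - 1;
        return (it->first <= cpt && cpt <= it->last) ? it->nfd : cpt;
    }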
@@ -611,31 +601,19 @@ std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8) {
     return result;
 }
 
-int unicode_cpt_type(uint32_t cp) {
-    static std::unordered_map<uint32_t, int> cpt_types = unicode_cpt_type_map();
-    const auto it = cpt_types.find(cp);
-    return it == cpt_types.end() ? CODEPOINT_TYPE_UNIDENTIFIED : it->second;
+codepoint_flags unicode_cpt_flags(const uint32_t cp) {
+    static const codepoint_flags undef(codepoint_flags::UNDEFINED);
+    static const auto cpt_flags = unicode_cpt_flags_array();
+    return cp < cpt_flags.size() ? cpt_flags[cp] : undef;
 }
 
-int unicode_cpt_type(const std::string & utf8) {
-    if (utf8.length() == 0) {
-        return CODEPOINT_TYPE_UNIDENTIFIED;
+codepoint_flags unicode_cpt_flags(const std::string & utf8) {
+    static const codepoint_flags undef(codepoint_flags::UNDEFINED);
+    if (utf8.empty()) {
+        return undef;  // undefined
     }
     size_t offset = 0;
-    return unicode_cpt_type(unicode_cpt_from_utf8(utf8, offset));
-}
-
-bool unicode_cpt_is_whitespace(uint32_t cp) {
-    static const std::unordered_set<uint32_t> is_whitespace = [] {
-        std::unordered_set<uint32_t> is_whitespace;
-        for (auto p : unicode_ranges_whitespace) {
-            for (auto i = p.first; i <= p.second; ++i) {
-                is_whitespace.insert(i);
-            }
-        }
-        return is_whitespace;
-    }();
-    return (bool)is_whitespace.count(cp);
+    return unicode_cpt_flags(unicode_cpt_from_utf8(utf8, offset));
 }
 
 std::string unicode_byte_to_utf8(uint8_t byte) {
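unicode_cpt_type becomes unicode_cpt_flags: an O(1) index into a flat table generated offline (unicode_cpt_flags_array) instead of an unordered_map probe, and the standalone unicode_cpt_is_whitespace set is dropped because whitespace is now just a bit in the flags. Any surviving call site could be covered by a shim like this (hypothetical, not in the commit):

    // Drop-in replacement for the removed helper.
    static bool unicode_cpt_is_whitespace(const uint32_t cp) {
        return unicode_cpt_flags(cp).is_whitespace;
    }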
@@ -656,21 +634,21 @@ char32_t unicode_tolower(char32_t cp) {
 std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs) {
     // unicode categories
     static const std::map<std::string, int> k_ucat_enum = {
-        { "\\p{N}", CODEPOINT_TYPE_NUMBER },
-        { "\\p{L}", CODEPOINT_TYPE_LETTER },
-        { "\\p{P}", CODEPOINT_TYPE_PUNCTUATION },
+        { "\\p{N}", codepoint_flags::NUMBER },
+        { "\\p{L}", codepoint_flags::LETTER },
+        { "\\p{P}", codepoint_flags::PUNCTUATION },
     };
 
     static const std::map<int, int> k_ucat_cpt = {
-        { CODEPOINT_TYPE_NUMBER,      0xD1 },
-        { CODEPOINT_TYPE_LETTER,      0xD2 },
-        { CODEPOINT_TYPE_PUNCTUATION, 0xD3 },
+        { codepoint_flags::NUMBER,      0xD1 },
+        { codepoint_flags::LETTER,      0xD2 },
+        { codepoint_flags::PUNCTUATION, 0xD3 },
     };
 
     static const std::map<int, std::string> k_ucat_map = {
-        { CODEPOINT_TYPE_NUMBER,      "\x30-\x39" },                // 0-9
-        { CODEPOINT_TYPE_LETTER,      "\x41-\x5A\x61-\x7A" },       // A-Za-z
-        { CODEPOINT_TYPE_PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\}
+        { codepoint_flags::NUMBER,      "\x30-\x39" },                // 0-9
+        { codepoint_flags::LETTER,      "\x41-\x5A\x61-\x7A" },       // A-Za-z
+        { codepoint_flags::PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\}
     };
 
     // compute collapsed codepoints only if needed by at least one regex
@@ -701,10 +679,10 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
                 continue;
             }
 
-            const int cpt_type = unicode_cpt_type(cpts[i]);
-            if (k_ucat_cpt.find(cpt_type) != k_ucat_cpt.end()) {
-                text_collapsed[i] = k_ucat_cpt.at(cpt_type);
+            const int cpt_flag = unicode_cpt_flags(cpts[i]).category_flag();
+            if (k_ucat_cpt.find(cpt_flag) != k_ucat_cpt.end()) {
+                text_collapsed[i] = k_ucat_cpt.at(cpt_flag);
             } else {
                 text_collapsed[i] = (char) 0xD0; // fallback
             }
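For patterns std::regex cannot express directly, both the text and the regex are 'collapsed': each non-ASCII codepoint is replaced by a single category byte (0xD1/0xD2/0xD3 per the maps above, 0xD0 as fallback) and \p{N}/\p{L}/\p{P} in the pattern become those same bytes, so a byte-level regex engine can emulate the Unicode classes. A per-codepoint sketch, assuming the continue above skips codepoints already handled (e.g. plain ASCII kept as-is, an assumption here):

    // Hypothetical helper mirroring the collapsing loop.
    static char collapse_cpt(const uint32_t cpt) {
        if (cpt < 128) {
            return (char) cpt;  // single-byte codepoints pass through
        }
        switch (unicode_cpt_flags(cpt).category_flag()) {
            case codepoint_flags::NUMBER:      return (char) 0xD1;  // \p{N}
            case codepoint_flags::LETTER:      return (char) 0xD2;  // \p{L}
            case codepoint_flags::PUNCTUATION: return (char) 0xD3;  // \p{P}
            default:                           return (char) 0xD0;  // fallback
        }
    }
    // e.g. 'é' -> 0xD2, '½' -> 0xD1, '¿' -> 0xD3, '©' (symbol) -> 0xD0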
unicode.h
View file
@@ -4,24 +4,56 @@
 #include <string>
 #include <vector>
 
-#define CODEPOINT_TYPE_UNIDENTIFIED 0
-#define CODEPOINT_TYPE_NUMBER       1
-#define CODEPOINT_TYPE_LETTER       2
-#define CODEPOINT_TYPE_SEPARATOR    3
-#define CODEPOINT_TYPE_ACCENT_MARK  4
-#define CODEPOINT_TYPE_PUNCTUATION  5
-#define CODEPOINT_TYPE_SYMBOL       6
-#define CODEPOINT_TYPE_CONTROL      7
+struct codepoint_flags {
+    enum {
+        UNDEFINED       = 0x0001,
+        NUMBER          = 0x0002,  // regex: \p{N}
+        LETTER          = 0x0004,  // regex: \p{L}
+        SEPARATOR       = 0x0008,  // regex: \p{Z}
+        ACCENT_MARK     = 0x0010,  // regex: \p{M}
+        PUNCTUATION     = 0x0020,  // regex: \p{P}
+        SYMBOL          = 0x0040,  // regex: \p{S}
+        CONTROL         = 0x0080,  // regex: \p{C}
+        MASK_CATEGORIES = 0x00FF,
+    };
+
+    // codepoint type
+    uint16_t is_undefined   : 1;
+    uint16_t is_number      : 1;  // regex: \p{N}
+    uint16_t is_letter      : 1;  // regex: \p{L}
+    uint16_t is_separator   : 1;  // regex: \p{Z}
+    uint16_t is_accent_mark : 1;  // regex: \p{M}
+    uint16_t is_punctuation : 1;  // regex: \p{P}
+    uint16_t is_symbol      : 1;  // regex: \p{S}
+    uint16_t is_control     : 1;  // regex: \p{C}
+
+    // helper flags
+    uint16_t is_whitespace : 1;  // regex: \s
+    uint16_t is_lowercase  : 1;
+    uint16_t is_uppercase  : 1;
+    uint16_t is_nfd        : 1;
+
+    // decode from uint16
+    inline codepoint_flags(const uint16_t flags = 0) {
+        *reinterpret_cast<uint16_t*>(this) = flags;
+    }
+
+    inline uint16_t as_uint() const {
+        return *reinterpret_cast<const uint16_t*>(this);
+    }
+
+    inline uint16_t category_flag() const {
+        return this->as_uint() & MASK_CATEGORIES;
+    }
+};
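The struct doubles as a bit-set and a plain uint16_t: the enum constants line up with the bit-field members, and the constructor / as_uint() pair round-trips through a cast. Note this relies on LSB-first bit-field packing (true for GCC, Clang and MSVC here, but formally implementation-defined). A usage sketch, where the 0x0100 helper-bit position is an assumption of that packing:

    #include <cassert>

    int main() {
        codepoint_flags f(codepoint_flags::LETTER);  // 0x0004
        assert(f.is_letter && !f.is_number);
        assert(f.category_flag() == codepoint_flags::LETTER);

        // helper bits sit above MASK_CATEGORIES and are masked out of the category
        codepoint_flags ws(codepoint_flags::SEPARATOR | 0x0100);  // 0x0100: is_whitespace
        assert(ws.is_whitespace);
        assert(ws.category_flag() == codepoint_flags::SEPARATOR);
        return 0;
    }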
 std::string unicode_cpt_to_utf8(uint32_t cp);
 std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8);
 std::vector<uint32_t> unicode_cpts_normalize_nfd(const std::vector<uint32_t> & cpts);
 
-int unicode_cpt_type(uint32_t cp);
-int unicode_cpt_type(const std::string & utf8);
-bool unicode_cpt_is_whitespace(uint32_t cp);
+codepoint_flags unicode_cpt_flags(const uint32_t cp);
+codepoint_flags unicode_cpt_flags(const std::string & utf8);
 
 std::string unicode_byte_to_utf8(uint8_t byte);
 uint8_t unicode_utf8_to_byte(const std::string & utf8);