Merge branch 'upstream' into concedo_experimental

# Conflicts:
#	.github/actions/windows-setup-curl/action.yml
#	.github/workflows/build-linux-cross.yml
#	README.md
#	common/CMakeLists.txt
#	examples/parallel/README.md
#	examples/parallel/parallel.cpp
#	ggml/src/ggml-sycl/element_wise.cpp
#	ggml/src/ggml-vulkan/CMakeLists.txt
#	tools/server/README.md
This commit is contained in:
Concedo 2025-05-18 23:27:53 +08:00
commit 59300dbdf5
25 changed files with 694 additions and 550 deletions

View file

@ -238,14 +238,19 @@ jobs:
matrix: matrix:
include: include:
- build: 'cpu-x64' - build: 'cpu-x64'
arch: 'x64'
defines: '-G "Ninja Multi-Config" -D CMAKE_TOOLCHAIN_FILE=cmake/x64-windows-llvm.cmake -DGGML_NATIVE=OFF -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON -DGGML_OPENMP=OFF' defines: '-G "Ninja Multi-Config" -D CMAKE_TOOLCHAIN_FILE=cmake/x64-windows-llvm.cmake -DGGML_NATIVE=OFF -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON -DGGML_OPENMP=OFF'
#- build: 'openblas-x64' #- build: 'openblas-x64'
# arch: 'x64'
# defines: '-G "Ninja Multi-Config" -D CMAKE_TOOLCHAIN_FILE=cmake/x64-windows-llvm.cmake -DGGML_NATIVE=OFF -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON -DGGML_OPENMP=OFF -DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS -DBLAS_INCLUDE_DIRS="$env:RUNNER_TEMP/openblas/include" -DBLAS_LIBRARIES="$env:RUNNER_TEMP/openblas/lib/openblas.lib"' # defines: '-G "Ninja Multi-Config" -D CMAKE_TOOLCHAIN_FILE=cmake/x64-windows-llvm.cmake -DGGML_NATIVE=OFF -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON -DGGML_OPENMP=OFF -DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS -DBLAS_INCLUDE_DIRS="$env:RUNNER_TEMP/openblas/include" -DBLAS_LIBRARIES="$env:RUNNER_TEMP/openblas/lib/openblas.lib"'
- build: 'vulkan-x64' - build: 'vulkan-x64'
arch: 'x64'
defines: '-DGGML_NATIVE=OFF -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON -DGGML_VULKAN=ON' defines: '-DGGML_NATIVE=OFF -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON -DGGML_VULKAN=ON'
- build: 'cpu-arm64' - build: 'cpu-arm64'
arch: 'arm64'
defines: '-G "Ninja Multi-Config" -D CMAKE_TOOLCHAIN_FILE=cmake/arm64-windows-llvm.cmake -DGGML_NATIVE=OFF' defines: '-G "Ninja Multi-Config" -D CMAKE_TOOLCHAIN_FILE=cmake/arm64-windows-llvm.cmake -DGGML_NATIVE=OFF'
- build: 'opencl-adreno-arm64' - build: 'opencl-adreno-arm64'
arch: 'arm64'
defines: '-G "Ninja Multi-Config" -D CMAKE_TOOLCHAIN_FILE=cmake/arm64-windows-llvm.cmake -DCMAKE_PREFIX_PATH="$env:RUNNER_TEMP/opencl-arm64-release" -DGGML_OPENCL=ON -DGGML_OPENCL_USE_ADRENO_KERNELS=ON' defines: '-G "Ninja Multi-Config" -D CMAKE_TOOLCHAIN_FILE=cmake/arm64-windows-llvm.cmake -DCMAKE_PREFIX_PATH="$env:RUNNER_TEMP/opencl-arm64-release" -DGGML_OPENCL=ON -DGGML_OPENCL_USE_ADRENO_KERNELS=ON'
steps: steps:
@ -312,6 +317,8 @@ jobs:
- name: libCURL - name: libCURL
id: get_libcurl id: get_libcurl
uses: ./.github/actions/windows-setup-curl uses: ./.github/actions/windows-setup-curl
with:
architecture: ${{ matrix.arch == 'x64' && 'win64' || 'win64a' }}
- name: Build - name: Build
id: cmake_build id: cmake_build
@ -339,7 +346,7 @@ jobs:
env: env:
CURL_PATH: ${{ steps.get_libcurl.outputs.curl_path }} CURL_PATH: ${{ steps.get_libcurl.outputs.curl_path }}
run: | run: |
Copy-Item $env:CURL_PATH\bin\libcurl-x64.dll .\build\bin\Release\libcurl-x64.dll Copy-Item $env:CURL_PATH\bin\libcurl-${{ matrix.arch }}.dll .\build\bin\Release\
7z a llama-${{ steps.tag.outputs.name }}-bin-win-${{ matrix.build }}.zip .\build\bin\Release\* 7z a llama-${{ steps.tag.outputs.name }}-bin-win-${{ matrix.build }}.zip .\build\bin\Release\*
- name: Upload artifacts - name: Upload artifacts

View file

@ -2586,7 +2586,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
[](common_params & params, int value) { [](common_params & params, int value) {
params.n_junk = value; params.n_junk = value;
} }
).set_examples({LLAMA_EXAMPLE_PASSKEY})); ).set_examples({LLAMA_EXAMPLE_PASSKEY, LLAMA_EXAMPLE_PARALLEL}));
add_opt(common_arg( add_opt(common_arg(
{"--pos"}, "N", {"--pos"}, "N",
string_format("position of the passkey in the junk text (default: %d)", params.i_pos), string_format("position of the passkey in the junk text (default: %d)", params.i_pos),
@ -2649,7 +2649,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
[](common_params & params) { [](common_params & params) {
params.is_pp_shared = true; params.is_pp_shared = true;
} }
).set_examples({LLAMA_EXAMPLE_BENCH})); ).set_examples({LLAMA_EXAMPLE_BENCH, LLAMA_EXAMPLE_PARALLEL}));
add_opt(common_arg( add_opt(common_arg(
{"-npp"}, "n0,n1,...", {"-npp"}, "n0,n1,...",
"number of prompt tokens", "number of prompt tokens",
@ -2881,6 +2881,16 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.chat_template = read_file(value); params.chat_template = read_file(value);
} }
).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CHAT_TEMPLATE_FILE")); ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CHAT_TEMPLATE_FILE"));
add_opt(common_arg(
{"--no-prefill-assistant"},
string_format(
"whether to prefill the assistant's response if the last message is an assistant message (default: prefill enabled)\n"
"when this flag is set, if the last message is an assistant message then it will be treated as a full message and not prefilled\n"
),
[](common_params & params) {
params.prefill_assistant = false;
}
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_NO_PREFILL_ASSISTANT"));
add_opt(common_arg( add_opt(common_arg(
{"-sps", "--slot-prompt-similarity"}, "SIMILARITY", {"-sps", "--slot-prompt-similarity"}, "SIMILARITY",
string_format("how much the prompt of a request must match the prompt of a slot in order to use that slot (default: %.2f, 0.0 = disabled)\n", params.slot_prompt_similarity), string_format("how much the prompt of a request must match the prompt of a slot in order to use that slot (default: %.2f, 0.0 = disabled)\n", params.slot_prompt_similarity),

View file

@ -364,6 +364,7 @@ struct common_params {
bool use_jinja = false; // NOLINT bool use_jinja = false; // NOLINT
bool enable_chat_template = true; bool enable_chat_template = true;
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
bool prefill_assistant = true; // if true, any trailing assistant message will be prefilled into the response
std::vector<std::string> api_keys; std::vector<std::string> api_keys;

View file

@ -415,6 +415,13 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H64,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H64,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H64,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H64,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96,
@ -1362,6 +1369,13 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128, flash_attn_ext_q8_0_hk192_hv128, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128, flash_attn_ext_q8_0_hk192_hv128, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512, flash_attn_ext_q8_0_hk576_hv512, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512, flash_attn_ext_q8_0_hk576_hv512, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64, flash_attn_ext_vec_f16_h64, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64, flash_attn_ext_vec_bf16_h64, has_simdgroup_reduction && use_bfloat);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64, flash_attn_ext_vec_q4_0_h64, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H64, flash_attn_ext_vec_q4_1_h64, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H64, flash_attn_ext_vec_q5_0_h64, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H64, flash_attn_ext_vec_q5_1_h64, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H64, flash_attn_ext_vec_q8_0_h64, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96, flash_attn_ext_vec_f16_h96, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96, flash_attn_ext_vec_f16_h96, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96, flash_attn_ext_vec_bf16_h96, has_simdgroup_reduction && use_bfloat); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96, flash_attn_ext_vec_bf16_h96, has_simdgroup_reduction && use_bfloat);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96, flash_attn_ext_vec_q4_0_h96, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96, flash_attn_ext_vec_q4_0_h96, has_simdgroup_reduction);
@ -4358,7 +4372,7 @@ static bool ggml_metal_encode_node(
// TODO: add vec kernels for (ne00%64 == 0) and maybe also for (ne00%32 == 0) // TODO: add vec kernels for (ne00%64 == 0) and maybe also for (ne00%32 == 0)
// for now avoiding mainly to keep the number of templates/kernels a bit lower // for now avoiding mainly to keep the number of templates/kernels a bit lower
// these are now trivial to add after: https://github.com/ggml-org/llama.cpp/pull/12612 // these are now trivial to add after: https://github.com/ggml-org/llama.cpp/pull/12612
if (ne01 >= 20 || (ne00%128 != 0 && ne00 != 96 && ne00 != 192 && ne00 != 576)) { if (ne01 >= 20 || (ne00%128 != 0 && ne00 != 64 && ne00 != 96 && ne00 != 192 && ne00 != 576)) {
switch (src1->type) { switch (src1->type) {
case GGML_TYPE_F16: case GGML_TYPE_F16:
{ {
@ -4539,6 +4553,24 @@ static bool ggml_metal_encode_node(
use_vec_kernel = true; use_vec_kernel = true;
switch (ne00) { switch (ne00) {
case 64:
{
switch (src1->type) {
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64].pipeline; break;
case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64].pipeline; break;
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64].pipeline; break;
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H64].pipeline; break;
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H64].pipeline; break;
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H64].pipeline; break;
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H64].pipeline; break;
default:
{
GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
GGML_LOG_ERROR("add template specialization for this type\n");
GGML_ABORT("add template specialization for this type");
}
}
} break;
case 96: case 96:
{ {
switch (src1->type) { switch (src1->type) {

View file

@ -4124,6 +4124,16 @@ kernel void kernel_flash_attn_ext_vec(
typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>) flash_attn_ext_vec_t; typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>) flash_attn_ext_vec_t;
template [[host_name("kernel_flash_attn_ext_vec_f16_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 64, 64, 8>;
#if defined(GGML_METAL_USE_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 64, 64, 8>;
#endif
template [[host_name("kernel_flash_attn_ext_vec_q4_0_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 64, 64, 8>;
template [[host_name("kernel_flash_attn_ext_vec_q4_1_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 64, 64, 8>;
template [[host_name("kernel_flash_attn_ext_vec_q5_0_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 64, 64, 8>;
template [[host_name("kernel_flash_attn_ext_vec_q5_1_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 64, 64, 8>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 64, 64, 8>;
template [[host_name("kernel_flash_attn_ext_vec_f16_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 96, 96, 4>; template [[host_name("kernel_flash_attn_ext_vec_f16_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 96, 96, 4>;
#if defined(GGML_METAL_USE_BF16) #if defined(GGML_METAL_USE_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 96, 96, 4>; template [[host_name("kernel_flash_attn_ext_vec_bf16_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 96, 96, 4>;

View file

@ -5896,10 +5896,17 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
vk_pipeline *pipelines; vk_pipeline *pipelines;
bool small_rows = N <= get_fa_num_small_rows(path); bool small_rows = N <= get_fa_num_small_rows(path);
// coopmat1 does not actually support "small rows" (it needs 16 rows).
// So use scalar instead.
if (small_rows && path == FA_COOPMAT1) { if (small_rows && path == FA_COOPMAT1) {
path = FA_SCALAR; path = FA_SCALAR;
} }
// scalar is faster than coopmat2 when N==1
if (N == 1 && path == FA_COOPMAT2) {
path = FA_SCALAR;
}
bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32; bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32;
switch (path) { switch (path) {

View file

@ -9,60 +9,13 @@
#extension GL_KHR_shader_subgroup_shuffle : enable #extension GL_KHR_shader_subgroup_shuffle : enable
#include "types.comp" #include "types.comp"
#include "flash_attn_base.comp"
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (constant_id = 0) const uint32_t WorkGroupSize = 128;
layout (constant_id = 1) const uint32_t Br = 1;
layout (constant_id = 2) const uint32_t Bc = 32;
layout (constant_id = 3) const uint32_t D = 32;
layout (constant_id = 5) const uint32_t D_split = 16;
const uint32_t D_per_thread = D / D_split; const uint32_t D_per_thread = D / D_split;
const uint32_t cols_per_iter = WorkGroupSize / D_split; const uint32_t cols_per_iter = WorkGroupSize / D_split;
const uint32_t cols_per_thread = Bc / cols_per_iter; const uint32_t cols_per_thread = Bc / cols_per_iter;
layout (push_constant) uniform parameter {
uint32_t N;
uint32_t KV;
uint32_t ne1;
uint32_t ne2;
uint32_t ne3;
uint32_t neq2;
uint32_t neq3;
uint32_t nek2;
uint32_t nek3;
uint32_t nev2;
uint32_t nev3;
uint32_t nem1;
uint32_t nb01;
uint32_t nb02;
uint32_t nb03;
uint32_t nb11;
uint32_t nb12;
uint32_t nb13;
uint32_t nb21;
uint32_t nb22;
uint32_t nb23;
uint32_t nb31;
float scale;
float max_bias;
float logit_softcap;
uint32_t mask;
uint32_t n_head_log2;
float m0;
float m1;
uint32_t gqa_ratio;
uint32_t split_kv;
uint32_t k_num;
} p;
layout (binding = 0) readonly buffer Q {float data_q[];}; layout (binding = 0) readonly buffer Q {float data_q[];};
layout (binding = 0) readonly buffer QV4 {vec4 data_qv4[];}; layout (binding = 0) readonly buffer QV4 {vec4 data_qv4[];};
@ -71,39 +24,6 @@ layout (binding = 1) readonly buffer KV4 {f16vec4 data_kv4[];};
layout (binding = 2) readonly buffer V {float16_t data_v[];}; layout (binding = 2) readonly buffer V {float16_t data_v[];};
layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];}; layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];};
layout (binding = 3) readonly buffer M {float16_t data_m[];}; layout (binding = 3) readonly buffer M {float16_t data_m[];};
layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};
#if defined(A_TYPE_PACKED16)
#define BINDING_IDX_K 0
#define BINDING_IDX_V 1
layout (binding = 1) readonly buffer KV_PACKED16 {A_TYPE_PACKED16 data_packed16[];} kv_packed[2];
#endif
#if defined(DATA_A_Q4_0)
#define BLOCK_BYTE_SIZE 18
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
uint vui_lo = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
uint vui_hi = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
uint shift = (iqs & 0x10) >> 2;
vui_lo >>= shift;
vui_hi >>= shift;
return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f);
}
#endif
#if defined(DATA_A_Q8_0)
#define BLOCK_BYTE_SIZE 34
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
const i8vec2 v0 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
const i8vec2 v1 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;
return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y);
}
#endif
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
// Store the output when doing grouped query attention. // Store the output when doing grouped query attention.
// Rows index by Q's dimension 2, and the first N rows are valid. // Rows index by Q's dimension 2, and the first N rows are valid.
@ -114,27 +34,6 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TY
return elem; return elem;
} }
// Store column zero. This is used to save per-row m and L values for split_k.
ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
{
if (r < N && c == 0) {
uint32_t offset = iq2 + r;
data_o[o_offset + offset] = D_TYPE(elem);
}
return elem;
}
// Load the slope matrix, indexed by Q's dimension 2.
ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
{
const uint32_t h = iq2 + (r % p.gqa_ratio);
const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);
return ACC_TYPE(pow(base, ACC_TYPE(exph)));
}
shared FLOAT_TYPE tmpsh[WorkGroupSize]; shared FLOAT_TYPE tmpsh[WorkGroupSize];
shared vec4 tmpshv4[WorkGroupSize]; shared vec4 tmpshv4[WorkGroupSize];
@ -146,58 +45,12 @@ void main() {
init_iq_shmem(gl_WorkGroupSize); init_iq_shmem(gl_WorkGroupSize);
#endif #endif
const uint32_t tid = gl_LocalInvocationIndex; init_indices();
const uint32_t N = p.N;
const uint32_t KV = p.KV;
const uint32_t tid = gl_LocalInvocationIndex;
const uint32_t d_tid = gl_LocalInvocationIndex % D_split; const uint32_t d_tid = gl_LocalInvocationIndex % D_split;
const uint32_t col_tid = gl_LocalInvocationIndex / D_split; const uint32_t col_tid = gl_LocalInvocationIndex / D_split;
uint32_t i = gl_WorkGroupID.x;
uint32_t split_k_index = 0;
if (p.k_num > 1) {
i = 0;
split_k_index = gl_WorkGroupID.x;
}
const uint32_t Tr = CEIL_DIV(N, Br);
const uint32_t start_j = split_k_index * p.split_kv / Bc;
const uint32_t end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc);
// When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y.
// When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2.
const uint32_t iq2 = gl_WorkGroupID.y * p.gqa_ratio;
const uint32_t iq3 = gl_WorkGroupID.z;
// broadcast factors
const uint32_t rk2 = p.neq2/p.nek2;
const uint32_t rk3 = p.neq3/p.nek3;
const uint32_t rv2 = p.neq2/p.nev2;
const uint32_t rv3 = p.neq3/p.nev3;
// k indices
const uint32_t ik3 = iq3 / rk3;
const uint32_t ik2 = iq2 / rk2;
// v indices
const uint32_t iv3 = iq3 / rv3;
const uint32_t iv2 = iq2 / rv2;
// nb?1 are already divided by the type size and are in units of elements.
// When using grouped query attention, Q is indexed by iq2, so the stride
// should be nb02 (which is in bytes).
uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01;
uint32_t k_stride = p.nb11;
uint32_t v_stride = p.nb21;
// When using grouped query attention, all rows use the same mask (stride 0).
// "p.gqa_ratio >> 16" is just a roundabout way of writing zero
// that prevents the compiler from folding the "&" through the select
// and breaking the alignment detection.
uint32_t m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV;
uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4; uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
[[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) { [[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) {

View file

@ -0,0 +1,162 @@
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (constant_id = 0) const uint32_t WorkGroupSize = 128;
layout (constant_id = 1) const uint32_t Br = 1;
layout (constant_id = 2) const uint32_t Bc = 32;
layout (constant_id = 3) const uint32_t D = 32;
layout (constant_id = 4) const uint32_t Clamp = 0;
layout (constant_id = 5) const uint32_t D_split = 16;
layout (push_constant) uniform parameter {
uint32_t N;
uint32_t KV;
uint32_t ne1;
uint32_t ne2;
uint32_t ne3;
uint32_t neq2;
uint32_t neq3;
uint32_t nek2;
uint32_t nek3;
uint32_t nev2;
uint32_t nev3;
uint32_t nem1;
uint32_t nb01;
uint32_t nb02;
uint32_t nb03;
uint32_t nb11;
uint32_t nb12;
uint32_t nb13;
uint32_t nb21;
uint32_t nb22;
uint32_t nb23;
uint32_t nb31;
float scale;
float max_bias;
float logit_softcap;
uint32_t mask;
uint32_t n_head_log2;
float m0;
float m1;
uint32_t gqa_ratio;
uint32_t split_kv;
uint32_t k_num;
} p;
layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};
#if defined(A_TYPE_PACKED16)
#define BINDING_IDX_K 0
#define BINDING_IDX_V 1
layout (binding = 1) readonly buffer KV_PACKED16 {A_TYPE_PACKED16 data_packed16[];} kv_packed[2];
#endif
#if defined(DATA_A_Q4_0)
#define BLOCK_BYTE_SIZE 18
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
uint vui_lo = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
uint vui_hi = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
uint shift = (iqs & 0x10) >> 2;
vui_lo >>= shift;
vui_hi >>= shift;
return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f);
}
#endif
#if defined(DATA_A_Q8_0)
#define BLOCK_BYTE_SIZE 34
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
const i8vec2 v0 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
const i8vec2 v1 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;
return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y);
}
#endif
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
// Store column zero. This is used to save per-row m and L values for split_k.
ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
{
if (r < N && c == 0) {
uint32_t offset = iq2 + r;
data_o[o_offset + offset] = D_TYPE(elem);
}
return elem;
}
// Load the slope matrix, indexed by Q's dimension 2.
ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
{
const uint32_t h = iq2 + (r % p.gqa_ratio);
const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);
return ACC_TYPE(pow(base, ACC_TYPE(exph)));
}
uint32_t i, N, KV, split_k_index, Tr, start_j, end_j,
iq2, iq3, rk2, rk3, rv2, rv3, ik2, ik3, iv2, iv3,
q_stride, k_stride, v_stride, m_stride;
void init_indices()
{
N = p.N;
KV = p.KV;
i = gl_WorkGroupID.x;
split_k_index = 0;
if (p.k_num > 1) {
i = 0;
split_k_index = gl_WorkGroupID.x;
}
Tr = CEIL_DIV(N, Br);
start_j = split_k_index * p.split_kv / Bc;
end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc);
// When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y.
// When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2.
iq2 = gl_WorkGroupID.y * p.gqa_ratio;
iq3 = gl_WorkGroupID.z;
// broadcast factors
rk2 = p.neq2/p.nek2;
rk3 = p.neq3/p.nek3;
rv2 = p.neq2/p.nev2;
rv3 = p.neq3/p.nev3;
// k indices
ik3 = iq3 / rk3;
ik2 = iq2 / rk2;
// v indices
iv3 = iq3 / rv3;
iv2 = iq2 / rv2;
// nb?1 are already divided by the type size and are in units of elements.
// When using grouped query attention, Q is indexed by iq2, so the stride
// should be nb02 (which is in bytes).
q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01;
k_stride = p.nb11;
v_stride = p.nb21;
// When using grouped query attention, all rows use the same mask (stride 0).
// "p.gqa_ratio >> 16" is just a roundabout way of writing zero
// that prevents the compiler from folding the "&" through the select
// and breaking the alignment detection.
m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV;
}

View file

@ -11,14 +11,7 @@
#extension GL_KHR_cooperative_matrix : enable #extension GL_KHR_cooperative_matrix : enable
#include "types.comp" #include "types.comp"
#include "flash_attn_base.comp"
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (constant_id = 1) const uint32_t Br = 1;
layout (constant_id = 2) const uint32_t Bc = 32;
layout (constant_id = 3) const uint32_t D = 32;
layout (constant_id = 5) const uint32_t D_split = 16;
const uint32_t D_per_thread = D / D_split; const uint32_t D_per_thread = D / D_split;
const uint32_t row_split = 4; const uint32_t row_split = 4;
@ -26,46 +19,6 @@ const uint32_t rows_per_thread = Br / row_split;
const uint32_t cols_per_iter = gl_WorkGroupSize.x / D_split / row_split; const uint32_t cols_per_iter = gl_WorkGroupSize.x / D_split / row_split;
const uint32_t cols_per_thread = Bc / cols_per_iter; const uint32_t cols_per_thread = Bc / cols_per_iter;
layout (push_constant) uniform parameter {
uint32_t N;
uint32_t KV;
uint32_t ne1;
uint32_t ne2;
uint32_t ne3;
uint32_t neq2;
uint32_t neq3;
uint32_t nek2;
uint32_t nek3;
uint32_t nev2;
uint32_t nev3;
uint32_t nem1;
uint32_t nb01;
uint32_t nb02;
uint32_t nb03;
uint32_t nb11;
uint32_t nb12;
uint32_t nb13;
uint32_t nb21;
uint32_t nb22;
uint32_t nb23;
uint32_t nb31;
float scale;
float max_bias;
float logit_softcap;
uint32_t mask;
uint32_t n_head_log2;
float m0;
float m1;
uint32_t gqa_ratio;
uint32_t split_kv;
uint32_t k_num;
} p;
layout (binding = 0) readonly buffer Q {float data_q[];}; layout (binding = 0) readonly buffer Q {float data_q[];};
layout (binding = 0) readonly buffer QV4 {vec4 data_qv4[];}; layout (binding = 0) readonly buffer QV4 {vec4 data_qv4[];};
@ -74,39 +27,6 @@ layout (binding = 1) readonly buffer KV4 {f16vec4 data_kv4[];};
layout (binding = 2) readonly buffer V {float16_t data_v[];}; layout (binding = 2) readonly buffer V {float16_t data_v[];};
layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];}; layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];};
layout (binding = 3) readonly buffer M {float16_t data_m[];}; layout (binding = 3) readonly buffer M {float16_t data_m[];};
layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};
#if defined(A_TYPE_PACKED16)
#define BINDING_IDX_K 0
#define BINDING_IDX_V 1
layout (binding = 1) readonly buffer KV_PACKED16 {A_TYPE_PACKED16 data_packed16[];} kv_packed[2];
#endif
#if defined(DATA_A_Q4_0)
#define BLOCK_BYTE_SIZE 18
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
uint vui_lo = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
uint vui_hi = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
uint shift = (iqs & 0x10) >> 2;
vui_lo >>= shift;
vui_hi >>= shift;
return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f);
}
#endif
#if defined(DATA_A_Q8_0)
#define BLOCK_BYTE_SIZE 34
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
const i8vec2 v0 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
const i8vec2 v1 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;
return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y);
}
#endif
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
// Store the output when doing grouped query attention. // Store the output when doing grouped query attention.
// Rows index by Q's dimension 2, and the first N rows are valid. // Rows index by Q's dimension 2, and the first N rows are valid.
@ -117,27 +37,6 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TY
return elem; return elem;
} }
// Store column zero. This is used to save per-row m and L values for split_k.
ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
{
if (r < N && c == 0) {
uint32_t offset = iq2 + r;
data_o[o_offset + offset] = D_TYPE(elem);
}
return elem;
}
// Load the slope matrix, indexed by Q's dimension 2.
ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
{
const uint32_t h = iq2 + (r % p.gqa_ratio);
const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);
return ACC_TYPE(pow(base, ACC_TYPE(exph)));
}
// These need to be supported N,M values for a MatBc x MatBr x 16 coopmatmuladd // These need to be supported N,M values for a MatBc x MatBr x 16 coopmatmuladd
const uint32_t MatBr = 16; const uint32_t MatBr = 16;
const uint32_t MatBc = 16; const uint32_t MatBc = 16;
@ -162,9 +61,9 @@ void main() {
init_iq_shmem(gl_WorkGroupSize); init_iq_shmem(gl_WorkGroupSize);
#endif #endif
init_indices();
const uint32_t tid = gl_LocalInvocationIndex; const uint32_t tid = gl_LocalInvocationIndex;
const uint32_t N = p.N;
const uint32_t KV = p.KV;
const uint32_t threads_per_rowgroup = gl_WorkGroupSize.x / row_split; const uint32_t threads_per_rowgroup = gl_WorkGroupSize.x / row_split;
const uint32_t row_tid = gl_LocalInvocationIndex / threads_per_rowgroup; const uint32_t row_tid = gl_LocalInvocationIndex / threads_per_rowgroup;
@ -173,51 +72,6 @@ void main() {
#define tile_row(r) (row_tid * rows_per_thread + (r)) #define tile_row(r) (row_tid * rows_per_thread + (r))
uint32_t i = gl_WorkGroupID.x;
uint32_t split_k_index = 0;
if (p.k_num > 1) {
i = 0;
split_k_index = gl_WorkGroupID.x;
}
const uint32_t Tr = CEIL_DIV(N, Br);
const uint32_t start_j = split_k_index * p.split_kv / Bc;
const uint32_t end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc);
// When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y.
// When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2.
const uint32_t iq2 = gl_WorkGroupID.y * p.gqa_ratio;
const uint32_t iq3 = gl_WorkGroupID.z;
// broadcast factors
const uint32_t rk2 = p.neq2/p.nek2;
const uint32_t rk3 = p.neq3/p.nek3;
const uint32_t rv2 = p.neq2/p.nev2;
const uint32_t rv3 = p.neq3/p.nev3;
// k indices
const uint32_t ik3 = iq3 / rk3;
const uint32_t ik2 = iq2 / rk2;
// v indices
const uint32_t iv3 = iq3 / rv3;
const uint32_t iv2 = iq2 / rv2;
// nb?1 are already divided by the type size and are in units of elements.
// When using grouped query attention, Q is indexed by iq2, so the stride
// should be nb02 (which is in bytes).
uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01;
uint32_t k_stride = p.nb11;
uint32_t v_stride = p.nb21;
// When using grouped query attention, all rows use the same mask (stride 0).
// "p.gqa_ratio >> 16" is just a roundabout way of writing zero
// that prevents the compiler from folding the "&" through the select
// and breaking the alignment detection.
uint32_t m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV;
uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4; uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
[[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) { [[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) {

View file

@ -18,62 +18,12 @@
#include "types.comp" #include "types.comp"
#include "dequant_funcs_cm2.comp" #include "dequant_funcs_cm2.comp"
#include "flash_attn_base.comp"
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (constant_id = 1) const uint32_t Br = 32;
layout (constant_id = 2) const uint32_t Bc = 32;
layout (constant_id = 3) const uint32_t D = 32;
layout (constant_id = 4) const uint32_t Clamp = gl_CooperativeMatrixClampModeConstantNV;
layout (push_constant) uniform parameter {
uint32_t N;
uint32_t KV;
uint32_t ne1;
uint32_t ne2;
uint32_t ne3;
uint32_t neq2;
uint32_t neq3;
uint32_t nek2;
uint32_t nek3;
uint32_t nev2;
uint32_t nev3;
uint32_t nem1;
uint32_t nb01;
uint32_t nb02;
uint32_t nb03;
uint32_t nb11;
uint32_t nb12;
uint32_t nb13;
uint32_t nb21;
uint32_t nb22;
uint32_t nb23;
uint32_t nb31;
float scale;
float max_bias;
float logit_softcap;
uint32_t mask;
uint32_t n_head_log2;
float m0;
float m1;
uint32_t gqa_ratio;
uint32_t split_kv;
uint32_t k_num;
} p;
layout (binding = 0) readonly buffer Q {uint8_t data_q[];}; layout (binding = 0) readonly buffer Q {uint8_t data_q[];};
layout (binding = 1) readonly buffer K {uint8_t data_k[];}; layout (binding = 1) readonly buffer K {uint8_t data_k[];};
layout (binding = 2) readonly buffer V {uint8_t data_v[];}; layout (binding = 2) readonly buffer V {uint8_t data_v[];};
layout (binding = 3) readonly buffer M {uint8_t data_m[];}; layout (binding = 3) readonly buffer M {uint8_t data_m[];};
layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
ACC_TYPE maxReduce(const in ACC_TYPE x, const in ACC_TYPE y) { ACC_TYPE maxReduce(const in ACC_TYPE x, const in ACC_TYPE y) {
return max(x, y); return max(x, y);
@ -118,67 +68,12 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TY
return elem; return elem;
} }
// Store column zero. This is used to save per-row m and L values for split_k.
ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
{
if (r < N && c == 0) {
uint32_t offset = iq2 + r;
data_o[o_offset + offset] = D_TYPE(elem);
}
return elem;
}
// Load the slope matrix, indexed by Q's dimension 2.
ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
{
const uint32_t h = iq2 + (r % p.gqa_ratio);
const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);
return ACC_TYPE(pow(base, ACC_TYPE(exph)));
}
void main() { void main() {
#ifdef NEEDS_INIT_IQ_SHMEM #ifdef NEEDS_INIT_IQ_SHMEM
init_iq_shmem(gl_WorkGroupSize); init_iq_shmem(gl_WorkGroupSize);
#endif #endif
const uint32_t N = p.N; init_indices();
const uint32_t KV = p.KV;
uint32_t i = gl_WorkGroupID.x;
uint32_t split_k_index = 0;
if (p.k_num > 1) {
i = 0;
split_k_index = gl_WorkGroupID.x;
}
const uint32_t Tr = CEIL_DIV(N, Br);
const uint32_t start_j = split_k_index * p.split_kv / Bc;
const uint32_t end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc);
// When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y.
// When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2.
const uint32_t iq2 = gl_WorkGroupID.y * p.gqa_ratio;
const uint32_t iq3 = gl_WorkGroupID.z;
// broadcast factors
const uint32_t rk2 = p.neq2/p.nek2;
const uint32_t rk3 = p.neq3/p.nek3;
const uint32_t rv2 = p.neq2/p.nev2;
const uint32_t rv3 = p.neq3/p.nev3;
// k indices
const uint32_t ik3 = iq3 / rk3;
const uint32_t ik2 = iq2 / rk2;
// v indices
const uint32_t iv3 = iq3 / rv3;
const uint32_t iv2 = iq2 / rv2;
tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutQ = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutQ = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
tensorLayoutNV<2, Clamp> tensorLayoutK = createTensorLayoutNV(2, Clamp); tensorLayoutNV<2, Clamp> tensorLayoutK = createTensorLayoutNV(2, Clamp);
@ -195,17 +90,6 @@ void main() {
tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, D); tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, D);
tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, D); tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, D);
// nb?1 are already divided by the type size and are in units of elements.
// When using grouped query attention, Q is indexed by iq2, so the stride
// should be nb02 (which is in bytes).
uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01;
uint32_t k_stride = p.nb11;
uint32_t v_stride = p.nb21;
// When using grouped query attention, all rows use the same mask (stride 0).
// "p.gqa_ratio >> 16" is just a roundabout way of writing zero
// that prevents the compiler from folding the "&" through the select
// and breaking the alignment detection.
uint32_t m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV;
// hint to the compiler that strides are aligned for the aligned variant of the shader // hint to the compiler that strides are aligned for the aligned variant of the shader
if (Clamp != gl_CooperativeMatrixClampModeConstantNV) if (Clamp != gl_CooperativeMatrixClampModeConstantNV)
{ {

View file

@ -12,7 +12,7 @@ Current version indicated by LITEVER below.
--> -->
<script> <script>
const LITEVER = 242; const LITEVER = 243;
const urlParams = new URLSearchParams(window.location.search); const urlParams = new URLSearchParams(window.location.search);
var localflag = urlParams.get('local'); //this will be replaced automatically in embedded kcpp var localflag = urlParams.get('local'); //this will be replaced automatically in embedded kcpp
const STORAGE_PREFIX = (localflag?"e_":"")+"kaihordewebui_"; const STORAGE_PREFIX = (localflag?"e_":"")+"kaihordewebui_";
@ -3256,6 +3256,7 @@ Current version indicated by LITEVER below.
instruct_has_latex: true, instruct_has_latex: true,
placeholder_tags: true, placeholder_tags: true,
render_special_tags: false, render_special_tags: false,
inject_randomness_seed: -1,
request_logprobs: false, request_logprobs: false,
persist_session: true, persist_session: true,
speech_synth: 0, //0 is disabled, 1000 is xtts speech_synth: 0, //0 is disabled, 1000 is xtts
@ -3392,7 +3393,7 @@ Current version indicated by LITEVER below.
rep_pen_slope: defaultsettings.rep_pen_slope, rep_pen_slope: defaultsettings.rep_pen_slope,
sampler_order: defaultsettings.sampler_order sampler_order: defaultsettings.sampler_order
}, },
{"preset":"Simple Logical","description":"A very predictable preset with low randomness.","temp":0.3,"dynatemp_range":0.0,"dynatemp_exponent":1.0,"smoothing_factor":0.0,"nsigma":0.0,"top_k":100,"top_p":0.6,"min_p":0.0,"presence_penalty":0.0,"top_a":0,"typical":1,"tfs":1,"rep_pen":1.02,"rep_pen_range":360,"rep_pen_slope":0.7,"sampler_order":[6,0,1,3,4,2,5]},{"preset":"Simple Balanced","description":"A good balanced preset with medium randomness.","temp":0.75,"dynatemp_range":0.0,"dynatemp_exponent":1.0,"smoothing_factor":0.0,"nsigma":0.0,"top_k":100,"top_p":0.92,"min_p":0.0,"presence_penalty":0.0,"top_a":0,"typical":1,"tfs":1,"rep_pen":1.07,"rep_pen_range":360,"rep_pen_slope":0.7,"sampler_order":[6,0,1,3,4,2,5]},{"preset":"Simple Creative","description":"A wild and unpredictable preset with higher randomness.","temp":1.0,"dynatemp_range":0.0,"dynatemp_exponent":1.0,"smoothing_factor":0.0,"nsigma":0.0,"top_k":100,"top_p":0.98,"min_p":0.0,"presence_penalty":0.0,"top_a":0,"typical":1,"tfs":1,"rep_pen":1.15,"rep_pen_range":360,"rep_pen_slope":0.7,"sampler_order":[6,0,1,3,4,2,5]},{"preset":"Basic Min-P","description":"A good default for Min-P, only works on backends with min-p.","temp":1.25,"dynatemp_range":0.0,"dynatemp_exponent":1.0,"smoothing_factor":0.0,"nsigma":0.0,"top_k":0,"top_p":1,"min_p":0.1,"presence_penalty":0.0,"top_a":0,"typical":1,"tfs":1,"rep_pen":1.03,"rep_pen_range":360,"rep_pen_slope":0.7,"sampler_order":[6,5,0,1,3,4,2]},{"preset":"Basic Top-nsigma","description":"A good default for Top-nsigma, only works on backends with Top-nsigma.","temp":1,"dynatemp_range":0.0,"dynatemp_exponent":1.0,"smoothing_factor":0.0,"nsigma":1.0,"top_k":0,"top_p":1,"min_p":0.01,"presence_penalty":0.0,"top_a":0,"typical":1,"tfs":1,"rep_pen":1,"rep_pen_range":360,"rep_pen_slope":0.7,"sampler_order":[6,5,0,1,3,4,2]},{"preset":"Basic DynaTemp","description":"A good default for DynaTemp, only works on backends with it.","temp":1.25,"dynatemp_range":0.75,"dynatemp_exponent":1.0,"smoothing_factor":0.0,"nsigma":0.0,"top_k":0,"top_p":1,"min_p":0.05,"presence_penalty":0.0,"top_a":0,"typical":1,"tfs":1,"rep_pen":1.03,"rep_pen_range":360,"rep_pen_slope":0.7,"sampler_order":[6,5,0,1,3,4,2]},{"preset":"Basic SmoothSample","description":"A good default for Smooth Sampling, only works on backends with it.","temp":1.0,"dynatemp_range":0.0,"dynatemp_exponent":1.0,"smoothing_factor":0.25,"top_k":0,"top_p":1,"min_p":0.05,"presence_penalty":0.0,"top_a":0,"typical":1,"tfs":1,"rep_pen":1.03,"rep_pen_range":360,"rep_pen_slope":0.7,"sampler_order":[6,5,0,1,3,4,2]},{"preset":"Basic SillyTavern","description":"Similar to default preset used in SillyTavern.","temp":0.75,"dynatemp_range":0.0,"dynatemp_exponent":1.0,"smoothing_factor":0.0,"nsigma":0.0,"top_k":40,"top_p":0.6,"min_p":0,"presence_penalty":0.0,"top_a":0,"typical":1,"tfs":1.0,"rep_pen":1.18,"rep_pen_range":1024,"rep_pen_slope":0.8,"sampler_order":[6,0,1,3,4,2,5]},{"preset":"Neutral (Disabled)","description":"Sets all samplers neutralized, allowing you to customize your own.","temp":1.0,"dynatemp_range":0.0,"dynatemp_exponent":1.0,"smoothing_factor":0.0,"nsigma":0.0,"top_k":200,"top_p":1.0,"min_p":0.0,"presence_penalty":0.0,"top_a":0,"typical":1,"tfs":1,"rep_pen":1.0,"rep_pen_range":360,"rep_pen_slope":0.7,"sampler_order":[6,0,1,3,4,2,5]},{"preset":"CoherentCreativity (Legacy)","description":"Legacy preset. A good balance between coherence, creativity, and quality of prose.","rep_pen":1.2,"rep_pen_range":360,"rep_pen_slope":0,"sampler_order":[6,5,0,2,3,1,4],"temp":0.5,"dynatemp_range":0.0,"dynatemp_exponent":1.0,"smoothing_factor":0.0,"nsigma":0.0,"tfs":0.99,"top_a":0,"top_k":0,"top_p":1,"min_p":0.0,"presence_penalty":0.0,"typical":1},{"preset":"Godlike (Legacy)","description":"Legacy preset. Makes AI give a descriptive and sensual output.","temp":0.7,"dynatemp_range":0.0,"dynatemp_exponent":1.0,"smoothing_factor":0.0,"nsigma":0.0,"top_k":0,"top_p":0.5,"min_p":0.0,"presence_penalty":0.0,"top_a":0.75,"typical":0.19,"tfs":0.97,"rep_pen":1.1,"rep_pen_range":1024,"rep_pen_slope":0.7,"sampler_order":[6,5,4,3,2,1,0]},{"preset":"LiminalDrift (Legacy)","description":"Legacy preset. Sometimes surreal situations arise based on information already present in the story.","temp":0.66,"dynatemp_range":0.0,"dynatemp_exponent":1.0,"smoothing_factor":0.0,"nsigma":0.0,"top_k":0,"top_p":1,"min_p":0.0,"presence_penalty":0.0,"top_a":0.96,"typical":0.6,"tfs":1,"rep_pen":1.1,"rep_pen_range":1024,"rep_pen_slope":0.7,"sampler_order":[6,4,5,1,0,2,3]} {"preset":"Simple Logical","description":"A very predictable preset with low randomness.","temp":0.3,"dynatemp_range":0.0,"dynatemp_exponent":1.0,"smoothing_factor":0.0,"nsigma":0.0,"top_k":100,"top_p":0.6,"min_p":0.0,"presence_penalty":0.0,"top_a":0,"typical":1,"tfs":1,"rep_pen":1.02,"rep_pen_range":360,"rep_pen_slope":0.7,"sampler_order":[6,0,1,3,4,2,5]},{"preset":"Simple Balanced","description":"A good balanced preset with medium randomness.","temp":0.75,"dynatemp_range":0.0,"dynatemp_exponent":1.0,"smoothing_factor":0.0,"nsigma":0.0,"top_k":100,"top_p":0.92,"min_p":0.0,"presence_penalty":0.0,"top_a":0,"typical":1,"tfs":1,"rep_pen":1.07,"rep_pen_range":360,"rep_pen_slope":0.7,"sampler_order":[6,0,1,3,4,2,5]},{"preset":"Simple Creative","description":"A wild and unpredictable preset with higher randomness.","temp":1.0,"dynatemp_range":0.0,"dynatemp_exponent":1.0,"smoothing_factor":0.0,"nsigma":0.0,"top_k":100,"top_p":0.98,"min_p":0.0,"presence_penalty":0.0,"top_a":0,"typical":1,"tfs":1,"rep_pen":1.15,"rep_pen_range":360,"rep_pen_slope":0.7,"sampler_order":[6,0,1,3,4,2,5]},{"preset":"Basic Min-P","description":"A good default for Min-P, only works on backends with min-p.","temp":1.25,"dynatemp_range":0.0,"dynatemp_exponent":1.0,"smoothing_factor":0.0,"nsigma":0.0,"top_k":0,"top_p":1,"min_p":0.1,"presence_penalty":0.0,"top_a":0,"typical":1,"tfs":1,"rep_pen":1.03,"rep_pen_range":360,"rep_pen_slope":0.7,"sampler_order":[6,5,0,1,3,4,2]},{"preset":"Basic Top-nsigma","description":"A good default for Top-nsigma, only works on backends with Top-nsigma.","temp":1,"dynatemp_range":0.0,"dynatemp_exponent":1.0,"smoothing_factor":0.0,"nsigma":1.0,"top_k":0,"top_p":1,"min_p":0.01,"presence_penalty":0.0,"top_a":0,"typical":1,"tfs":1,"rep_pen":1,"rep_pen_range":360,"rep_pen_slope":0.7,"sampler_order":[6,5,0,1,3,4,2]},{"preset":"Basic DynaTemp","description":"A good default for DynaTemp, only works on backends with it.","temp":1.25,"dynatemp_range":0.75,"dynatemp_exponent":1.0,"smoothing_factor":0.0,"nsigma":0.0,"top_k":0,"top_p":1,"min_p":0.05,"presence_penalty":0.0,"top_a":0,"typical":1,"tfs":1,"rep_pen":1.03,"rep_pen_range":360,"rep_pen_slope":0.7,"sampler_order":[6,5,0,1,3,4,2]},{"preset":"Basic SmoothSample","description":"A good default for Smooth Sampling, only works on backends with it.","temp":1.0,"dynatemp_range":0.0,"dynatemp_exponent":1.0,"smoothing_factor":0.25,"nsigma":0.0,"top_k":0,"top_p":1,"min_p":0.05,"presence_penalty":0.0,"top_a":0,"typical":1,"tfs":1,"rep_pen":1.03,"rep_pen_range":360,"rep_pen_slope":0.7,"sampler_order":[6,5,0,1,3,4,2]},{"preset":"Basic SillyTavern","description":"Similar to default preset used in SillyTavern.","temp":0.75,"dynatemp_range":0.0,"dynatemp_exponent":1.0,"smoothing_factor":0.0,"nsigma":0.0,"top_k":40,"top_p":0.6,"min_p":0,"presence_penalty":0.0,"top_a":0,"typical":1,"tfs":1.0,"rep_pen":1.18,"rep_pen_range":1024,"rep_pen_slope":0.8,"sampler_order":[6,0,1,3,4,2,5]},{"preset":"Immortal","description":"Modernized version of the Godlike preset, designed for creative and longer story co-writing use.","temp":0.7,"dynatemp_range":0.0,"dynatemp_exponent":1.0,"smoothing_factor":0.0,"nsigma":1.75,"top_k":0,"top_p":1.0,"min_p":0.0,"presence_penalty":0.0,"top_a":0.75,"typical":0.19,"tfs":0.97,"rep_pen":1.1,"rep_pen_range":1024,"rep_pen_slope":0.7,"sampler_order":[6,5,4,3,2,1,0]},{"preset":"Neutral (Disabled)","description":"Sets all samplers neutralized, allowing you to customize your own.","temp":1.0,"dynatemp_range":0.0,"dynatemp_exponent":1.0,"smoothing_factor":0.0,"nsigma":0.0,"top_k":200,"top_p":1.0,"min_p":0.0,"presence_penalty":0.0,"top_a":0,"typical":1,"tfs":1,"rep_pen":1.0,"rep_pen_range":360,"rep_pen_slope":0.7,"sampler_order":[6,0,1,3,4,2,5]},{"preset":"CoherentCreativity (Legacy)","description":"Legacy preset. A good balance between coherence, creativity, and quality of prose.","rep_pen":1.2,"rep_pen_range":360,"rep_pen_slope":0,"sampler_order":[6,5,0,2,3,1,4],"temp":0.5,"dynatemp_range":0.0,"dynatemp_exponent":1.0,"smoothing_factor":0.0,"nsigma":0.0,"tfs":0.99,"top_a":0,"top_k":0,"top_p":1,"min_p":0.0,"presence_penalty":0.0,"typical":1},{"preset":"Godlike (Legacy)","description":"Legacy preset. Makes AI give a descriptive and sensual output.","temp":0.7,"dynatemp_range":0.0,"dynatemp_exponent":1.0,"smoothing_factor":0.0,"nsigma":0.0,"top_k":0,"top_p":0.5,"min_p":0.0,"presence_penalty":0.0,"top_a":0.75,"typical":0.19,"tfs":0.97,"rep_pen":1.1,"rep_pen_range":1024,"rep_pen_slope":0.7,"sampler_order":[6,5,4,3,2,1,0]},{"preset":"LiminalDrift (Legacy)","description":"Legacy preset. Sometimes surreal situations arise based on information already present in the story.","temp":0.66,"dynatemp_range":0.0,"dynatemp_exponent":1.0,"smoothing_factor":0.0,"nsigma":0.0,"top_k":0,"top_p":1,"min_p":0.0,"presence_penalty":0.0,"top_a":0.96,"typical":0.6,"tfs":1,"rep_pen":1.1,"rep_pen_range":1024,"rep_pen_slope":0.7,"sampler_order":[6,4,5,1,0,2,3]}
]; ];
const instructpresets = [ const instructpresets = [
@ -3588,6 +3589,15 @@ Current version indicated by LITEVER below.
//truncate to first 3 bytes //truncate to first 3 bytes
return hsh.substring(0, 6); return hsh.substring(0, 6);
}; };
function basic_lcg(seed) { // simple RNG for reproducible dice rolls
var m = Math.pow(2, 35) - 31;
var a = 185852;
var s = seed % m;
return function () {
s = (s * a) % m;
return s / m;
};
}
function import_props_into_object(existingObj, objToImport) { //import new fields from one object into another while preserving exsting function import_props_into_object(existingObj, objToImport) { //import new fields from one object into another while preserving exsting
for (var k in objToImport) { for (var k in objToImport) {
existingObj[k] = objToImport[k]; existingObj[k] = objToImport[k];
@ -4325,7 +4335,7 @@ Current version indicated by LITEVER below.
text = remove_all_instruct_tags(text); text = remove_all_instruct_tags(text);
// Replace {{user}} and other placeholders // Replace {{user}} and other placeholders
text = replace_placeholders(text); text = replace_placeholders(text,false,false,false);
return text; return text;
} }
@ -4347,7 +4357,7 @@ Current version indicated by LITEVER below.
inputtxt = replaceAll(inputtxt,instructsysplaceholder_end.trim(),get_instruct_systag_end(false)); inputtxt = replaceAll(inputtxt,instructsysplaceholder_end.trim(),get_instruct_systag_end(false));
return inputtxt; return inputtxt;
} }
function replace_noninstruct_placeholders(inputtxt,escape=false) function replace_noninstruct_placeholders(inputtxt,escape=false,isgametext=true,persistent_countmap=null)
{ {
let firstopponent = localsettings.chatopponent; let firstopponent = localsettings.chatopponent;
if (firstopponent && firstopponent.includes("||$||")) { if (firstopponent && firstopponent.includes("||$||")) {
@ -4373,15 +4383,92 @@ Current version indicated by LITEVER below.
inputtxt = replaceAll(inputtxt,localsettings.placeholder_tags_data[i].p,localsettings.placeholder_tags_data[i].r); inputtxt = replaceAll(inputtxt,localsettings.placeholder_tags_data[i].p,localsettings.placeholder_tags_data[i].r);
} }
} }
if(localsettings.inject_randomness_seed>0)
{
// define helper functions
let dicenotation = function(formula,seed=0){
formula = formula.trim();
if (/^\d+$/.test(formula)) { formula = `1d${formula}`; }
const pieces = formula.match(/^(\d+)?d(\d+)([x*\u00D7]\d*\.?\d+)?([+-]\d*\.?\d+)?$/i);
if (!pieces) {
return ""; //if the dice test fails we still need to return the original unswapped content!
}
let RNG = seed ? basic_lcg(seed) : Math.random;
let numDice = parseInt(pieces[1]) || 1;
let dieSides = parseInt(pieces[2]);
if (isNaN(numDice) || isNaN(dieSides) || dieSides < 0 || dieSides > 99999) { //dice above 99999 cause extreme slowdown
console.warn("Avoid potential dice issues");
return "";
}
let total = 0;
for (let i = 0; i < numDice; ++i) {
total += Math.floor(RNG() * dieSides) + 1; // Standard dice rolls are 1-based
}
total *= pieces[3] ? parseFloat(pieces[3].substring(1)) : 1;
total += pieces[4] ? parseFloat(pieces[4]) : 0;
return String(total);
}
let pickfromlist = function(listString,seed=0){
let RNG = seed ? basic_lcg(seed) : Math.random;
const list = listString.includes('::') ? listString.split('::') : listString.replace(/\\,/g, '%%COMMA%%').split(',').map(item => item.trim().replace(/%%COMMA%%/g, ','));
return (list.length===0 ? '' : list[Math.floor(RNG()*list.length)]);
}
//don't waste time if we dont have any such tags
let lowerinputtxt = inputtxt.toLowerCase();
if (lowerinputtxt.includes("\{\{pick") || lowerinputtxt.includes("\{\{random") || lowerinputtxt.includes("\{\{roll"))
{
let stablereg = "";
if(!isgametext)
{
inputtxt = inputtxt.replace(/\{\{roll\s?:?:?([^\}\n]{1,500})\}\}/gi, function (m,matchValue) {
return dicenotation(matchValue, 0);
});
inputtxt = inputtxt.replace(/\{\{random\s?:?:?([^\}\n]{1,500})\}\}/gi, function (m,listString) {
return pickfromlist(listString,0);
});
stablereg = /\{\{(pick)\s?:?:?([^\}\n]{1,500})\}\}/gi;
}
else
{
stablereg = /\{\{(pick|random|roll)\s?:?:?([^\}\n]{1,500})\}\}/gi;
}
let countmap = (persistent_countmap?persistent_countmap:new Map());
let parts = [];
let lastIndex = 0;
let match;
while ((match = stablereg.exec(inputtxt)) !== null) {
const tagType = match[1].toLowerCase();
const matchtxt = match[2];
parts.push(inputtxt.substring(lastIndex, match.index));
let count = countmap.get(matchtxt) || 0;
countmap.set(matchtxt, count + 1);
let hashstr = matchtxt + localsettings.inject_randomness_seed;
if (isgametext || tagType !== "pick") {
hashstr += count;
}
let seed = parseInt(cyrb_hash(hashstr), 16);
let replacement = (tagType === "roll" ? dicenotation : pickfromlist)(matchtxt, seed);
parts.push(replacement);
lastIndex = stablereg.lastIndex;
}
parts.push(inputtxt.substring(lastIndex)); // Add remaining text after the last match
inputtxt = parts.join('');
}
}
return inputtxt; return inputtxt;
} }
//if alwaysreplace, then settings are not considered, otherwise checks settings //if alwaysreplace, then settings are not considered, otherwise checks settings
function replace_placeholders(inputtxt, escape=false, alwaysreplace=false) //if isgametext is true, then keeping the random values stable will be prioritized
function replace_placeholders(inputtxt, escape=false, alwaysreplace=false, isgametext=true)
{ {
inputtxt = replace_instruct_placeholders(inputtxt); inputtxt = replace_instruct_placeholders(inputtxt);
if(alwaysreplace || localsettings.placeholder_tags) if(alwaysreplace || localsettings.placeholder_tags)
{ {
inputtxt = replace_noninstruct_placeholders(inputtxt,escape); inputtxt = replace_noninstruct_placeholders(inputtxt,escape,isgametext,null);
} }
return inputtxt; return inputtxt;
} }
@ -4775,15 +4862,18 @@ Current version indicated by LITEVER below.
var unzip = new Zlib.Unzip(compressed); var unzip = new Zlib.Unzip(compressed);
var filenames = unzip.getFilenames(); var filenames = unzip.getFilenames();
let foundfile = filenames.filter(x=>x.includes(".json")); let foundfile = filenames.filter(x=>x.includes(".json"));
let foundfile2 = filenames.filter(x=>x.includes("story.json"));
if(foundfile.length>0) if(foundfile.length>0)
{ {
try { try {
var plainfile = unzip.decompress(filenames[0]); let correctfile = foundfile[0];
let readtxt = ""; if(foundfile2.length>0)
for(let i=0;i<plainfile.length;++i)
{ {
readtxt += String.fromCharCode(plainfile[i]); correctfile = foundfile2[0];
} }
var plainfile = unzip.decompress(correctfile);
const decoder = new TextDecoder('utf-8');
const readtxt = decoder.decode(plainfile);
var decoded = JSON.parse(readtxt); var decoded = JSON.parse(readtxt);
console.log(decoded); console.log(decoded);
return decoded; return decoded;
@ -4794,6 +4884,77 @@ Current version indicated by LITEVER below.
} }
return null; return null;
}; };
function extractRisuData(arr) {
function strToBytes(str) {
return new TextEncoder().encode(str);
}
function indexOfSequence(haystack, needle, from = 0) {
for (let i = from; i <= haystack.length - needle.length; i++) {
let match = true;
for (let j = 0; j < needle.length; j++) {
if (haystack[i + j] !== needle[j]) {
match = false;
break;
}
}
if (match) return i;
}
return -1;
}
function lastIndexOfSequence(haystack, needle) {
for (let i = haystack.length - needle.length; i >= 0; i--) {
let match = true;
for (let j = 0; j < needle.length; j++) {
if (haystack[i + j] !== needle[j]) {
match = false;
break;
}
}
if (match) return i;
}
return -1;
}
console.log("Attempting RISU import...");
const charCardFinderSig = strToBytes("chara_card_v");
const charcardpos = indexOfSequence(arr, charCardFinderSig);
if(charcardpos!=-1)
{
const zipStartSig = strToBytes("PK\x03\x04");
const zipEndSig = strToBytes("PK\x05\x06");
const start = indexOfSequence(arr, zipStartSig);
const end = lastIndexOfSequence(arr, zipEndSig);
if (start !== -1 && end !== -1 && end > start && charcardpos > start && charcardpos < end) {
try {
const zipData = arr.slice(start, end + 22); // Add 22 bytes for EOCD record
var unzip = new Zlib.Unzip(zipData);
var filenames = unzip.getFilenames();
let foundfile = filenames.filter(x => x=="card.json");
if (foundfile.length > 0) {
var plainfile = unzip.decompress(foundfile[0]);
const decoder = new TextDecoder('utf-8');
const readtxt = decoder.decode(plainfile);
var decoded = JSON.parse(readtxt);
console.log(decoded);
return decoded;
}
console.log("RISU card missing card.json.");
return null;
} catch (e) {
console.log("Error decoding RISU card: " + e);
return null;
}
} else {
console.log("Not a valid RISU card.");
return null;
}
}
console.log("Not a RISU card.");
return null;
}
function generate_compressed_story(save_images,export_settings,export_aesthetic_settings) { function generate_compressed_story(save_images,export_settings,export_aesthetic_settings) {
//encode the current story into a sharable url //encode the current story into a sharable url
//a tiny json format which gets compressed by LZMA then b64url //a tiny json format which gets compressed by LZMA then b64url
@ -7006,61 +7167,80 @@ Current version indicated by LITEVER below.
} }
} }
} catch (e) { } catch (e) {
console.log(e) console.log(e);
//attempt to parse it as a png file //attempt to parse it as a png file
var pngfr = new FileReader(); function handlePngLoadDone(img) {
pngfr.onload = function (img) { const data = pngfr.result;
var data = pngfr.result; const arr = new Uint8Array(data);
var arr = new Uint8Array(data);
var result = convertTavernPng(arr); function setImgAsAvatar(data)
if (result != null) { {
load_tavern_obj(result); compressImage(data, (compressedImageURI, aspectratio) => {
//replace portraits
compressImage(data, (compressedImageURI, aspectratio)=>{
aestheticInstructUISettings.AI_portrait = compressedImageURI; aestheticInstructUISettings.AI_portrait = compressedImageURI;
document.getElementById('portrait_ratio_AI').value = aspectratio.toFixed(2); document.getElementById('portrait_ratio_AI').value = aspectratio.toFixed(2);
refreshAestheticPreview(true); refreshAestheticPreview(true);
render_gametext(); render_gametext();
}, true, AVATAR_PX); }, true, AVATAR_PX);
} }
else {
//attempt to read as WEBP
result = getTavernExifJSON(arr);
if (result != null) {
load_tavern_obj(result);
}
else {
//attempt to read as KAISTORY
try {
result = UnzipKAISTORYFile(arr);
} catch (error) {
console.log("Unzip failed: " + error);
result = null;
}
if (result != null) { // 1. Try Tavern PNG
kai_json_load(result,false); let result = convertTavernPng(arr);
} if (result) {
else { load_tavern_obj(result);
if (selectedFilename.endsWith(".txt")) { setImgAsAvatar(data);
msgboxYesNo("Could not load selected file!<br><span class=\"color_red\">It appears to be invalid or corrupted!</span><br><br>Do you still want to import it as plaintext?", "Loading Failed", return;
() => {
//raw text import
restart_new_game(false);
gametext_arr.push(text);
render_gametext(true);
sync_multiplayer(true);
update_for_sidepanel();
}, null, true)
} else {
msgbox("Could not load selected file. Is it valid?");
}
}
}
} }
};
// 2. Try Tavern EXIF
result = getTavernExifJSON(arr);
if (result) {
load_tavern_obj(result);
setImgAsAvatar(data);
return;
}
// 3. Try Risu V3
result = extractRisuData(arr);
if (result) {
load_tavern_obj(result);
setImgAsAvatar(data);
return;
}
// 4. Try KAISTORY
try {
result = UnzipKAISTORYFile(arr);
if (result) {
kai_json_load(result, false);
return;
}
} catch (error) {
console.log("Unzip failed: " + error);
}
// 5. Fallback to plaintext if .txt
if (selectedFilename.endsWith(".txt")) {
msgboxYesNo(
"Could not load selected file!<br><span class=\"color_red\">It appears to be invalid or corrupted!</span><br><br>Do you still want to import it as plaintext?",
"Loading Failed",
() => {
restart_new_game(false);
gametext_arr.push(text);
render_gametext(true);
sync_multiplayer(true);
update_for_sidepanel();
},
null,
true
);
} else {
msgbox("Could not load selected file. Is it valid?");
}
}
const pngfr = new FileReader();
pngfr.onload = handlePngLoadDone;
pngfr.readAsArrayBuffer(selectedFile); pngfr.readAsArrayBuffer(selectedFile);
} }
@ -7507,7 +7687,7 @@ Current version indicated by LITEVER below.
let selectedgreeting = ""; let selectedgreeting = "";
let load_tav_obj_confirm_p1 = function(usechatmode) // need second input for alt greeting let load_tav_obj_confirm_p1 = function(usechatmode) // need second input for alt greeting
{ {
if(obj.spec=="chara_card_v2" && obj.data!=null) if((obj.spec=="chara_card_v2"||obj.spec=="chara_card_v3") && obj.data!=null)
{ {
obj = obj.data; obj = obj.data;
} }
@ -7973,8 +8153,8 @@ Current version indicated by LITEVER below.
gametext_arr = []; gametext_arr = [];
if(!localsettings.placeholder_tags) //do a one-time replace instead if(!localsettings.placeholder_tags) //do a one-time replace instead
{ {
greeting = replace_placeholders(greeting, false, true); greeting = replace_placeholders(greeting, false, true, false); //this is all onetime replacements, no need for stability
combinedmem = replace_placeholders(combinedmem, false, true); combinedmem = replace_placeholders(combinedmem, false, true, false);
} }
if(greeting) if(greeting)
{ {
@ -8033,7 +8213,7 @@ Current version indicated by LITEVER below.
let prompttxt = temp_scenario.prompt; let prompttxt = temp_scenario.prompt;
if(!localsettings.placeholder_tags) //do a one-time replace instead if(!localsettings.placeholder_tags) //do a one-time replace instead
{ {
prompttxt = replace_placeholders(prompttxt, false, true); prompttxt = replace_placeholders(prompttxt, false, true, false);
} }
gametext_arr.push(prompttxt); gametext_arr.push(prompttxt);
} }
@ -8041,14 +8221,14 @@ Current version indicated by LITEVER below.
current_anote = temp_scenario.authorsnote; current_anote = temp_scenario.authorsnote;
if(!localsettings.placeholder_tags) if(!localsettings.placeholder_tags)
{ {
current_anote = replace_placeholders(current_anote, false, true); current_anote = replace_placeholders(current_anote, false, true, false);
} }
} }
if (temp_scenario.memory != "") { if (temp_scenario.memory != "") {
current_memory = temp_scenario.memory; current_memory = temp_scenario.memory;
if(!localsettings.placeholder_tags) if(!localsettings.placeholder_tags)
{ {
current_memory = replace_placeholders(current_memory, false, true); current_memory = replace_placeholders(current_memory, false, true, false);
} }
} }
if (temp_scenario.image && temp_scenario.image != "") { if (temp_scenario.image && temp_scenario.image != "") {
@ -8253,6 +8433,7 @@ Current version indicated by LITEVER below.
userInput = userInput.split(/characterhub\.org\//i)[1].split("#")[0].split("?")[0]; userInput = userInput.split(/characterhub\.org\//i)[1].split("#")[0].split("?")[0];
} }
userInput = userInput.endsWith('/') ? userInput.slice(0, -1) : userInput; userInput = userInput.endsWith('/') ? userInput.slice(0, -1) : userInput;
userInput = userInput.startsWith('/') ? userInput.slice(1) : userInput;
return userInput; return userInput;
}, },
fetch: (userInput) => { fetch: (userInput) => {
@ -8270,16 +8451,8 @@ Current version indicated by LITEVER below.
}; };
//try to obtain the full portrait image //try to obtain the full portrait image
return fetch("https://api.chub.ai/api/characters/download", { return fetch(`https://avatars.charhub.io/avatars/${userInput}/chara_card_v2.png`, {
method: 'POST', method: 'GET',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({
"format": "tavern",
"fullPath": userInput,
"version": "main"
}),
referrerPolicy: 'no-referrer', referrerPolicy: 'no-referrer',
}) })
.then(rb => { .then(rb => {
@ -11656,6 +11829,7 @@ Current version indicated by LITEVER below.
document.getElementById("instruct_has_markdown").checked = localsettings.instruct_has_markdown; document.getElementById("instruct_has_markdown").checked = localsettings.instruct_has_markdown;
document.getElementById("instruct_has_latex").checked = localsettings.instruct_has_latex; document.getElementById("instruct_has_latex").checked = localsettings.instruct_has_latex;
document.getElementById("placeholder_tags").checked = localsettings.placeholder_tags; document.getElementById("placeholder_tags").checked = localsettings.placeholder_tags;
document.getElementById("inject_randomness_seed").value = localsettings.inject_randomness_seed;
document.getElementById("run_in_background").checked = run_in_background; document.getElementById("run_in_background").checked = run_in_background;
document.getElementById("auto_ctxlen").checked = localsettings.auto_ctxlen; document.getElementById("auto_ctxlen").checked = localsettings.auto_ctxlen;
document.getElementById("auto_genamt").checked = localsettings.auto_genamt; document.getElementById("auto_genamt").checked = localsettings.auto_genamt;
@ -12121,6 +12295,7 @@ Current version indicated by LITEVER below.
localsettings.instruct_has_markdown = (document.getElementById("instruct_has_markdown").checked ? true : false); localsettings.instruct_has_markdown = (document.getElementById("instruct_has_markdown").checked ? true : false);
localsettings.instruct_has_latex = (document.getElementById("instruct_has_latex").checked ? true : false); localsettings.instruct_has_latex = (document.getElementById("instruct_has_latex").checked ? true : false);
localsettings.placeholder_tags = (document.getElementById("placeholder_tags").checked ? true : false); localsettings.placeholder_tags = (document.getElementById("placeholder_tags").checked ? true : false);
localsettings.inject_randomness_seed = document.getElementById("inject_randomness_seed").value;
run_in_background = (document.getElementById("run_in_background").checked ? true : false); run_in_background = (document.getElementById("run_in_background").checked ? true : false);
background_audio_loop(run_in_background); background_audio_loop(run_in_background);
localsettings.generate_images_model = document.getElementById("generate_images_model").value; localsettings.generate_images_model = document.getElementById("generate_images_model").value;
@ -13084,6 +13259,11 @@ Current version indicated by LITEVER below.
documentdb_chunksize = 800; documentdb_chunksize = 800;
documentdb_data = ""; documentdb_data = "";
} }
if(localsettings.inject_randomness_seed>0)
{
new_randomness_seed();
localsettings.inject_randomness_seed = document.getElementById("inject_randomness_seed").value;
}
warn_unsaved = false; warn_unsaved = false;
last_used_saveslot = -1; last_used_saveslot = -1;
show_corpo_leftpanel(false); show_corpo_leftpanel(false);
@ -13113,6 +13293,11 @@ Current version indicated by LITEVER below.
} }
function new_randomness_seed()
{
document.getElementById("inject_randomness_seed").value = (Math.floor(Math.random() * 99999)); //keep within a simple range
}
function btn_editmode() function btn_editmode()
{ {
document.getElementById("allowediting").checked = true; document.getElementById("allowediting").checked = true;
@ -13434,6 +13619,10 @@ Current version indicated by LITEVER below.
function do_manual_gen_image(sentence, base64img="") //b64 for img2img function do_manual_gen_image(sentence, base64img="") //b64 for img2img
{ {
if(localsettings.placeholder_tags)
{
sentence = replace_placeholders(sentence)
}
generate_new_image(sentence, base64img); generate_new_image(sentence, base64img);
document.getElementById("btn_genimg").disabled = true; document.getElementById("btn_genimg").disabled = true;
document.getElementById("btn_genimg2").disabled = true; document.getElementById("btn_genimg2").disabled = true;
@ -14959,6 +15148,8 @@ Current version indicated by LITEVER below.
//only do this processing if memory or anote is not blank //only do this processing if memory or anote is not blank
if (truncated_memory.length > 0 || current_anote.length > 0) if (truncated_memory.length > 0 || current_anote.length > 0)
{ {
truncated_memory = replace_placeholders(truncated_memory,false,false,true);
truncated_anote = replace_placeholders(truncated_anote,false,false,false);
if(!is_using_kcpp_with_added_memory()) if(!is_using_kcpp_with_added_memory())
{ {
let augmented_len = truncated_memory.length + truncated_context.length + truncated_anote.length; let augmented_len = truncated_memory.length + truncated_context.length + truncated_anote.length;
@ -14992,8 +15183,7 @@ Current version indicated by LITEVER below.
} }
} }
truncated_memory = replace_placeholders(truncated_memory); truncated_context = replace_placeholders(truncated_context,false,false,true);
truncated_context = replace_placeholders(truncated_context);
if(is_using_kcpp_with_added_memory()) if(is_using_kcpp_with_added_memory())
{ {
@ -19351,6 +19541,7 @@ Current version indicated by LITEVER below.
let incomplete_resp = (synchro_pending_stream!="" || pending_response_id!=""); let incomplete_resp = (synchro_pending_stream!="" || pending_response_id!="");
let countmap = new Map();
for(var i=0;i<chatunits.length;++i) for(var i=0;i<chatunits.length;++i)
{ {
let curr = chatunits[i]; let curr = chatunits[i];
@ -19359,7 +19550,7 @@ Current version indicated by LITEVER below.
processed_msg = apply_display_only_regex(processed_msg); processed_msg = apply_display_only_regex(processed_msg);
if(processed_msg && processed_msg!="") if(processed_msg && processed_msg!="")
{ {
processed_msg = replace_noninstruct_placeholders(processed_msg,true); processed_msg = replace_noninstruct_placeholders(processed_msg,true,true,countmap);
let codeblockcount = (processed_msg.match(/```/g) || []).length; let codeblockcount = (processed_msg.match(/```/g) || []).length;
if(codeblockcount>0 && codeblockcount%2!=0 ) if(codeblockcount>0 && codeblockcount%2!=0 )
{ {
@ -21428,9 +21619,9 @@ Current version indicated by LITEVER below.
allText = replaceAll(allText,recentTextStr,""); allText = replaceAll(allText,recentTextStr,"");
// Ensure placeholders are replaced to allow searching for user / character // Ensure placeholders are replaced to allow searching for user / character
allText = replace_search_placeholders(allText) allText = replace_search_placeholders(allText);
searchStr = replace_search_placeholders(searchStr) searchStr = replace_search_placeholders(searchStr);
recentTextStr = replace_search_placeholders(recentTextStr) recentTextStr = replace_search_placeholders(recentTextStr);
let i = 0, startLoc = 0; let i = 0, startLoc = 0;
while (startLoc < allText.length && i < Number.MAX_SAFE_INTEGER) { while (startLoc < allText.length && i < Number.MAX_SAFE_INTEGER) {
@ -22811,6 +23002,8 @@ Current version indicated by LITEVER below.
</div> </div>
<table id="placeholder_replace_table" class="settinglabel" style="text-align: center; border-spacing: 3px 2px; border-collapse: separate;"> <table id="placeholder_replace_table" class="settinglabel" style="text-align: center; border-spacing: 3px 2px; border-collapse: separate;">
</table> </table>
<div style="padding:3px" class="settinglabel justifyleft">
Randomness Seed: <span class="helpicon">?<span class="helptext">A seed to generate randomness for {{pick}}, {{roll}} and {{random}}. Set to -1 to disable these three placeholders.</span></span><input class="settinglabel miniinput" style="margin-left:4px;width:70px;" type="text" inputmode="decimal" placeholder="0" value="" id="inject_randomness_seed"><button title="Regen" type="button" class="btn btn-primary bg_green" style="padding:1px 3px;margin:0px 0px 0px auto;font-size:10px;" onclick="new_randomness_seed()">Regen</button></div>
</div> </div>
<div style="padding:3px;" class="justifyleft settinglabel">Classifier-Free Guidance <span class="helpicon">?<span <div style="padding:3px;" class="justifyleft settinglabel">Classifier-Free Guidance <span class="helpicon">?<span
class="helptext">Functions as a negative prompt when Guidance Scale is above 1.</span></span> class="helptext">Functions as a negative prompt when Guidance Scale is above 1.</span></span>
@ -22818,7 +23011,7 @@ Current version indicated by LITEVER below.
</div> </div>
<div id="expandguidance" class="hidden"> <div id="expandguidance" class="hidden">
<div class="color_red hidden" id="noguidance">Classifier-Free Guidance may be unavailable.</div> <div class="color_red hidden" id="noguidance">Classifier-Free Guidance may be unavailable.</div>
<div style="color:#ffffff;">Classifier-Free Guidance prompt functions as a negative prompt when Guidance Scale is above 1, and a positive prompt at Guidance Scale is above 1. Disabled if scale is exactly 1 or CFG prompt is blank.</em><br></div> <div style="color:#ffffff;">Classifier-Free Guidance prompt functions as a negative prompt when Guidance Scale is above 1, and a positive prompt at Guidance Scale is below 1. Disabled if scale is exactly 1 or CFG prompt is blank.</em><br></div>
<div style="display: flex; column-gap: 4px; margin-top: 4px; margin-bottom: 4px;"> <div style="display: flex; column-gap: 4px; margin-top: 4px; margin-bottom: 4px;">
<input class="form-control menuinput_inline" type="text" placeholder="Enter CFG Prompt" value="" id="guidance_prompt"> <input class="form-control menuinput_inline" type="text" placeholder="Enter CFG Prompt" value="" id="guidance_prompt">
<div style="padding:1px" class="settinglabel">Scale<br>(0-5): </div><input class="form-control menuinput_inline" style="margin-left:4px;width:70px;" inputmode="numeric" placeholder="(Off)" value="" id="guidance_scale"></div> <div style="padding:1px" class="settinglabel">Scale<br>(0-5): </div><input class="form-control menuinput_inline" style="margin-left:4px;width:70px;" inputmode="numeric" placeholder="(Off)" value="" id="guidance_scale"></div>
@ -23515,7 +23708,6 @@ Current version indicated by LITEVER below.
<div><input type="checkbox" id="useoainonstandard" title="Send Non-Standard Fields"> <div><input type="checkbox" id="useoainonstandard" title="Send Non-Standard Fields">
<div class="box-label">Non-Standard Fields</div></div> <div class="box-label">Non-Standard Fields</div></div>
</div> </div>
<div class="todoremove color_yellow">Looking for the Streaming Toggle? It's now in Advanced Settings -> Streaming!</div>
<span id="useoaichatcomplbox" class="hidden" onload="toggleoaichatcompl();"> <span id="useoaichatcomplbox" class="hidden" onload="toggleoaichatcompl();">
<br> <br>
Main Message Role: Main Message Role:
@ -23652,7 +23844,6 @@ Current version indicated by LITEVER below.
<div class="box-label">Allow Thinking</div> <div class="box-label">Allow Thinking</div>
</div> </div>
</div> </div>
<div class="todoremove color_yellow">Looking for the Streaming Toggle? It's now in Advanced Settings -> Streaming!</div>
</div> </div>
<div id="coherecustom" class="menutext hidden"> <div id="coherecustom" class="menutext hidden">
Uses Cohere's models through their own API.<br><br> Uses Cohere's models through their own API.<br><br>

View file

@ -168,6 +168,11 @@ static struct llama_model * llama_model_load_from_file_impl(
struct llama_model_params params) { struct llama_model_params params) {
ggml_time_init(); ggml_time_init();
if (!params.vocab_only && ggml_backend_reg_count() == 0) {
LLAMA_LOG_ERROR("%s: no backends are loaded. hint: use ggml_backend_load() or ggml_backend_load_all() to load a backend before calling this function\n", __func__);
return nullptr;
}
unsigned cur_percentage = 0; unsigned cur_percentage = 0;
if (params.progress_callback == NULL) { if (params.progress_callback == NULL) {
params.progress_callback_user_data = &cur_percentage; params.progress_callback_user_data = &cur_percentage;

Binary file not shown.

View file

@ -2251,6 +2251,14 @@ struct server_context {
slot.has_next_token = true; slot.has_next_token = true;
} }
// if context shifting is disabled, make sure that we don't run out of context
if (!params_base.ctx_shift && slot.n_past + 1 >= slot.n_ctx) {
slot.stop = STOP_TYPE_LIMIT;
slot.has_next_token = false;
SLT_DBG(slot, "stopped due to running out of context, n_past = %d, n_ctx = %d\n", slot.n_past, slot.n_ctx);
}
// check the limits // check the limits
if (slot.n_decoded > 0 && slot.has_next_token && !slot.has_budget(params_base)) { if (slot.n_decoded > 0 && slot.has_next_token && !slot.has_budget(params_base)) {
slot.stop = STOP_TYPE_LIMIT; slot.stop = STOP_TYPE_LIMIT;
@ -4340,6 +4348,7 @@ int main(int argc, char ** argv) {
json data = oaicompat_completion_params_parse( json data = oaicompat_completion_params_parse(
body, body,
params.use_jinja, params.use_jinja,
params.prefill_assistant,
params.reasoning_format, params.reasoning_format,
ctx_server.chat_templates.get(), ctx_server.chat_templates.get(),
ctx_server.mctx, ctx_server.mctx,
@ -4361,6 +4370,7 @@ int main(int argc, char ** argv) {
json data = oaicompat_completion_params_parse( json data = oaicompat_completion_params_parse(
body, body,
params.use_jinja, params.use_jinja,
params.prefill_assistant,
params.reasoning_format, params.reasoning_format,
ctx_server.chat_templates.get(), ctx_server.chat_templates.get(),
ctx_server.mctx, ctx_server.mctx,

View file

@ -65,3 +65,21 @@ def test_ctx_shift_disabled_long_prompt():
assert res.status_code != 200 assert res.status_code != 200
assert "error" in res.body assert "error" in res.body
assert "exceeds the available context size" in res.body["error"]["message"] assert "exceeds the available context size" in res.body["error"]["message"]
def test_ctx_shift_disabled_stream():
global server
server.disable_ctx_shift = True
server.start()
res = server.make_stream_request("POST", "/v1/completions", data={
"n_predict": 256,
"prompt": "Once",
"stream": True,
})
content = ""
for data in res:
choice = data["choices"][0]
if choice["finish_reason"] == "length":
assert len(content) > 0
else:
assert choice["finish_reason"] is None
content += choice["text"]

View file

@ -583,6 +583,7 @@ static json oaicompat_completion_params_parse(const json & body) {
static json oaicompat_completion_params_parse( static json oaicompat_completion_params_parse(
const json & body, /* openai api json semantics */ const json & body, /* openai api json semantics */
bool use_jinja, bool use_jinja,
bool prefill_assistant,
common_reasoning_format reasoning_format, common_reasoning_format reasoning_format,
const struct common_chat_templates * tmpls, const struct common_chat_templates * tmpls,
bool allow_non_text, bool allow_non_text,
@ -732,7 +733,7 @@ static json oaicompat_completion_params_parse(
// if the assistant message appears at the end of list, we do not add end-of-turn token // if the assistant message appears at the end of list, we do not add end-of-turn token
// for ex. this can be useful to modify the reasoning process in reasoning models // for ex. this can be useful to modify the reasoning process in reasoning models
bool prefill_assistant_message = !inputs.messages.empty() && inputs.messages.back().role == "assistant"; bool prefill_assistant_message = !inputs.messages.empty() && inputs.messages.back().role == "assistant" && prefill_assistant;
common_chat_msg last_message; common_chat_msg last_message;
if (prefill_assistant_message) { if (prefill_assistant_message) {
last_message = inputs.messages.back(); last_message = inputs.messages.back();

View file

@ -28,13 +28,13 @@ function AppLayout() {
return ( return (
<> <>
<Sidebar /> <Sidebar />
<div <main
className="drawer-content grow flex flex-col h-screen w-screen mx-auto px-4 overflow-auto bg-base-100" className="drawer-content grow flex flex-col h-screen w-screen mx-auto px-4 overflow-auto bg-base-100"
id="main-scroll" id="main-scroll"
> >
<Header /> <Header />
<Outlet /> <Outlet />
</div> </main>
{ {
<SettingDialog <SettingDialog
show={showSettings} show={showSettings}

View file

@ -18,16 +18,26 @@ export default function ChatInputExtraContextItem({
if (!items) return null; if (!items) return null;
return ( return (
<div className="flex flex-row gap-4 overflow-x-auto py-2 px-1 mb-1"> <div
className="flex flex-row gap-4 overflow-x-auto py-2 px-1 mb-1"
role="group"
aria-description="Selected files"
>
{items.map((item, i) => ( {items.map((item, i) => (
<div <div
className="indicator" className="indicator"
key={i} key={i}
onClick={() => clickToShow && setShow(i)} onClick={() => clickToShow && setShow(i)}
tabIndex={0}
aria-description={
clickToShow ? `Click to show: ${item.name}` : undefined
}
role={clickToShow ? 'button' : 'menuitem'}
> >
{removeItem && ( {removeItem && (
<div className="indicator-item indicator-top"> <div className="indicator-item indicator-top">
<button <button
aria-label="Remove file"
className="btn btn-neutral btn-sm w-4 h-4 p-0 rounded-full" className="btn btn-neutral btn-sm w-4 h-4 p-0 rounded-full"
onClick={() => removeItem(i)} onClick={() => removeItem(i)}
> >
@ -46,13 +56,16 @@ export default function ChatInputExtraContextItem({
<> <>
<img <img
src={item.base64Url} src={item.base64Url}
alt={item.name} alt={`Preview image for ${item.name}`}
className="w-14 h-14 object-cover rounded-md" className="w-14 h-14 object-cover rounded-md"
/> />
</> </>
) : ( ) : (
<> <>
<div className="w-14 h-14 flex items-center justify-center"> <div
className="w-14 h-14 flex items-center justify-center"
aria-description="Document icon"
>
<DocumentTextIcon className="h-8 w-14 text-base-content/50" /> <DocumentTextIcon className="h-8 w-14 text-base-content/50" />
</div> </div>
@ -66,16 +79,25 @@ export default function ChatInputExtraContextItem({
))} ))}
{showingItem && ( {showingItem && (
<dialog className="modal modal-open"> <dialog
className="modal modal-open"
aria-description={`Preview ${showingItem.name}`}
>
<div className="modal-box"> <div className="modal-box">
<div className="flex justify-between items-center mb-4"> <div className="flex justify-between items-center mb-4">
<b>{showingItem.name ?? 'Extra content'}</b> <b>{showingItem.name ?? 'Extra content'}</b>
<button className="btn btn-ghost btn-sm"> <button
className="btn btn-ghost btn-sm"
aria-label="Close preview dialog"
>
<XMarkIcon className="h-5 w-5" onClick={() => setShow(-1)} /> <XMarkIcon className="h-5 w-5" onClick={() => setShow(-1)} />
</button> </button>
</div> </div>
{showingItem.type === 'imageFile' ? ( {showingItem.type === 'imageFile' ? (
<img src={showingItem.base64Url} alt={showingItem.name} /> <img
src={showingItem.base64Url}
alt={`Preview image for ${showingItem.name}`}
/>
) : ( ) : (
<div className="overflow-x-auto"> <div className="overflow-x-auto">
<pre className="whitespace-pre-wrap break-words text-sm"> <pre className="whitespace-pre-wrap break-words text-sm">

View file

@ -83,13 +83,20 @@ export default function ChatMessage({
if (!viewingChat) return null; if (!viewingChat) return null;
const isUser = msg.role === 'user';
return ( return (
<div className="group" id={id}> <div
className="group"
id={id}
role="group"
aria-description={`Message from ${msg.role}`}
>
<div <div
className={classNames({ className={classNames({
chat: true, chat: true,
'chat-start': msg.role !== 'user', 'chat-start': !isUser,
'chat-end': msg.role === 'user', 'chat-end': isUser,
})} })}
> >
{msg.extra && msg.extra.length > 0 && ( {msg.extra && msg.extra.length > 0 && (
@ -99,7 +106,7 @@ export default function ChatMessage({
<div <div
className={classNames({ className={classNames({
'chat-bubble markdown': true, 'chat-bubble markdown': true,
'chat-bubble bg-transparent': msg.role !== 'user', 'chat-bubble bg-transparent': !isUser,
})} })}
> >
{/* textarea for editing message */} {/* textarea for editing message */}
@ -142,7 +149,7 @@ export default function ChatMessage({
) : ( ) : (
<> <>
{/* render message as markdown */} {/* render message as markdown */}
<div dir="auto"> <div dir="auto" tabIndex={0}>
{thought && ( {thought && (
<ThoughtProcess <ThoughtProcess
isThinking={!!isThinking && !!isPending} isThinking={!!isThinking && !!isPending}
@ -196,13 +203,18 @@ export default function ChatMessage({
})} })}
> >
{siblingLeafNodeIds && siblingLeafNodeIds.length > 1 && ( {siblingLeafNodeIds && siblingLeafNodeIds.length > 1 && (
<div className="flex gap-1 items-center opacity-60 text-sm"> <div
className="flex gap-1 items-center opacity-60 text-sm"
role="navigation"
aria-description={`Message version ${siblingCurrIdx + 1} of ${siblingLeafNodeIds.length}`}
>
<button <button
className={classNames({ className={classNames({
'btn btn-sm btn-ghost p-1': true, 'btn btn-sm btn-ghost p-1': true,
'opacity-20': !prevSibling, 'opacity-20': !prevSibling,
})} })}
onClick={() => prevSibling && onChangeSibling(prevSibling)} onClick={() => prevSibling && onChangeSibling(prevSibling)}
aria-label="Previous message version"
> >
<ChevronLeftIcon className="h-4 w-4" /> <ChevronLeftIcon className="h-4 w-4" />
</button> </button>
@ -215,6 +227,7 @@ export default function ChatMessage({
'opacity-20': !nextSibling, 'opacity-20': !nextSibling,
})} })}
onClick={() => nextSibling && onChangeSibling(nextSibling)} onClick={() => nextSibling && onChangeSibling(nextSibling)}
aria-label="Next message version"
> >
<ChevronRightIcon className="h-4 w-4" /> <ChevronRightIcon className="h-4 w-4" />
</button> </button>
@ -223,7 +236,7 @@ export default function ChatMessage({
{/* user message */} {/* user message */}
{msg.role === 'user' && ( {msg.role === 'user' && (
<BtnWithTooltips <BtnWithTooltips
className="btn-mini show-on-hover w-8 h-8" className="btn-mini w-8 h-8"
onClick={() => setEditingContent(msg.content)} onClick={() => setEditingContent(msg.content)}
disabled={msg.content === null} disabled={msg.content === null}
tooltipsContent="Edit message" tooltipsContent="Edit message"
@ -236,7 +249,7 @@ export default function ChatMessage({
<> <>
{!isPending && ( {!isPending && (
<BtnWithTooltips <BtnWithTooltips
className="btn-mini show-on-hover w-8 h-8" className="btn-mini w-8 h-8"
onClick={() => { onClick={() => {
if (msg.content !== null) { if (msg.content !== null) {
onRegenerateMessage(msg as Message); onRegenerateMessage(msg as Message);
@ -250,10 +263,7 @@ export default function ChatMessage({
)} )}
</> </>
)} )}
<CopyButton <CopyButton className="btn-mini w-8 h-8" content={msg.content} />
className="btn-mini show-on-hover w-8 h-8"
content={msg.content}
/>
</div> </div>
)} )}
</div> </div>
@ -271,6 +281,8 @@ function ThoughtProcess({
}) { }) {
return ( return (
<div <div
role="button"
aria-label="Toggle thought process display"
tabIndex={0} tabIndex={0}
className={classNames({ className={classNames({
'collapse bg-none': true, 'collapse bg-none': true,
@ -292,7 +304,11 @@ function ThoughtProcess({
)} )}
</div> </div>
</div> </div>
<div className="collapse-content text-base-content/70 text-sm p-1"> <div
className="collapse-content text-base-content/70 text-sm p-1"
tabIndex={0}
aria-description="Thought process content"
>
<div className="border-l-2 border-base-content/20 pl-4 mb-4"> <div className="border-l-2 border-base-content/20 pl-4 mb-4">
<MarkdownDisplay content={content} /> <MarkdownDisplay content={content} />
</div> </div>

View file

@ -279,7 +279,11 @@ export default function ChatScreen() {
function ServerInfo() { function ServerInfo() {
const { serverProps } = useAppContext(); const { serverProps } = useAppContext();
return ( return (
<div className="card card-sm shadow-sm border-1 border-base-content/20 text-base-content/70 mb-6"> <div
className="card card-sm shadow-sm border-1 border-base-content/20 text-base-content/70 mb-6"
tabIndex={0}
aria-description="Server information"
>
<div className="card-body"> <div className="card-body">
<b>Server Info</b> <b>Server Info</b>
<p> <p>
@ -311,6 +315,8 @@ function ChatInput({
return ( return (
<div <div
role="group"
aria-label="Chat input"
className={classNames({ className={classNames({
'flex items-end pt-8 pb-6 sticky bottom-0 bg-base-100': true, 'flex items-end pt-8 pb-6 sticky bottom-0 bg-base-100': true,
'opacity-50': isDrag, // simply visual feedback to inform user that the file will be accepted 'opacity-50': isDrag, // simply visual feedback to inform user that the file will be accepted
@ -400,13 +406,15 @@ function ChatInput({
'btn w-8 h-8 p-0 rounded-full': true, 'btn w-8 h-8 p-0 rounded-full': true,
'btn-disabled': isGenerating, 'btn-disabled': isGenerating,
})} })}
aria-label="Upload file"
tabIndex={0}
role="button"
> >
<PaperClipIcon className="h-5 w-5" /> <PaperClipIcon className="h-5 w-5" />
</label> </label>
<input <input
id="file-upload" id="file-upload"
type="file" type="file"
className="hidden"
disabled={isGenerating} disabled={isGenerating}
{...getInputProps()} {...getInputProps()}
hidden hidden
@ -422,6 +430,7 @@ function ChatInput({
<button <button
className="btn btn-primary w-8 h-8 p-0 rounded-full" className="btn btn-primary w-8 h-8 p-0 rounded-full"
onClick={onSend} onClick={onSend}
aria-label="Send message"
> >
<ArrowUpIcon className="h-5 w-5" /> <ArrowUpIcon className="h-5 w-5" />
</button> </button>

View file

@ -38,8 +38,12 @@ export default function Header() {
{/* action buttons (top right) */} {/* action buttons (top right) */}
<div className="flex items-center"> <div className="flex items-center">
<div className="tooltip tooltip-bottom" data-tip="Settings"> <div
<button className="btn" onClick={() => setShowSettings(true)}> className="tooltip tooltip-bottom"
data-tip="Settings"
onClick={() => setShowSettings(true)}
>
<button className="btn" aria-hidden={true}>
{/* settings button */} {/* settings button */}
<Cog8ToothIcon className="w-5 h-5" /> <Cog8ToothIcon className="w-5 h-5" />
</button> </button>

View file

@ -335,14 +335,22 @@ export default function SettingDialog({
}; };
return ( return (
<dialog className={classNames({ modal: true, 'modal-open': show })}> <dialog
className={classNames({ modal: true, 'modal-open': show })}
aria-label="Settings dialog"
>
<div className="modal-box w-11/12 max-w-3xl"> <div className="modal-box w-11/12 max-w-3xl">
<h3 className="text-lg font-bold mb-6">Settings</h3> <h3 className="text-lg font-bold mb-6">Settings</h3>
<div className="flex flex-col md:flex-row h-[calc(90vh-12rem)]"> <div className="flex flex-col md:flex-row h-[calc(90vh-12rem)]">
{/* Left panel, showing sections - Desktop version */} {/* Left panel, showing sections - Desktop version */}
<div className="hidden md:flex flex-col items-stretch pr-4 mr-4 border-r-2 border-base-200"> <div
className="hidden md:flex flex-col items-stretch pr-4 mr-4 border-r-2 border-base-200"
role="complementary"
aria-description="Settings sections"
tabIndex={0}
>
{SETTING_SECTIONS.map((section, idx) => ( {SETTING_SECTIONS.map((section, idx) => (
<div <button
key={idx} key={idx}
className={classNames({ className={classNames({
'btn btn-ghost justify-start font-normal w-44 mb-1': true, 'btn btn-ghost justify-start font-normal w-44 mb-1': true,
@ -352,12 +360,16 @@ export default function SettingDialog({
dir="auto" dir="auto"
> >
{section.title} {section.title}
</div> </button>
))} ))}
</div> </div>
{/* Left panel, showing sections - Mobile version */} {/* Left panel, showing sections - Mobile version */}
<div className="md:hidden flex flex-row gap-2 mb-4"> {/* This menu is skipped on a11y, otherwise it's repeated the desktop version */}
<div
className="md:hidden flex flex-row gap-2 mb-4"
aria-disabled={true}
>
<details className="dropdown"> <details className="dropdown">
<summary className="btn bt-sm w-full m-1"> <summary className="btn bt-sm w-full m-1">
{SETTING_SECTIONS[sectionIdx].title} {SETTING_SECTIONS[sectionIdx].title}

View file

@ -50,44 +50,72 @@ export default function Sidebar() {
id="toggle-drawer" id="toggle-drawer"
type="checkbox" type="checkbox"
className="drawer-toggle" className="drawer-toggle"
aria-label="Toggle sidebar"
defaultChecked defaultChecked
/> />
<div className="drawer-side h-screen lg:h-screen z-50 lg:max-w-64"> <div
className="drawer-side h-screen lg:h-screen z-50 lg:max-w-64"
role="complementary"
aria-label="Sidebar"
tabIndex={0}
>
<label <label
htmlFor="toggle-drawer" htmlFor="toggle-drawer"
aria-label="close sidebar" aria-label="Close sidebar"
className="drawer-overlay" className="drawer-overlay"
></label> ></label>
<a
href="#main-scroll"
className="absolute -left-80 top-0 w-1 h-1 overflow-hidden"
>
Skip to main content
</a>
<div className="flex flex-col bg-base-200 min-h-full max-w-64 py-4 px-4"> <div className="flex flex-col bg-base-200 min-h-full max-w-64 py-4 px-4">
<div className="flex flex-row items-center justify-between mb-4 mt-4"> <div className="flex flex-row items-center justify-between mb-4 mt-4">
<h2 className="font-bold ml-4">Conversations</h2> <h2 className="font-bold ml-4" role="heading">
Conversations
</h2>
{/* close sidebar button */} {/* close sidebar button */}
<label htmlFor="toggle-drawer" className="btn btn-ghost lg:hidden"> <label
htmlFor="toggle-drawer"
className="btn btn-ghost lg:hidden"
aria-label="Close sidebar"
role="button"
tabIndex={0}
>
<XMarkIcon className="w-5 h-5" /> <XMarkIcon className="w-5 h-5" />
</label> </label>
</div> </div>
{/* new conversation button */} {/* new conversation button */}
<div <button
className={classNames({ className={classNames({
'btn btn-ghost justify-start px-2': true, 'btn btn-ghost justify-start px-2': true,
'btn-soft': !currConv, 'btn-soft': !currConv,
})} })}
onClick={() => navigate('/')} onClick={() => navigate('/')}
aria-label="New conversation"
> >
<PencilSquareIcon className="w-5 h-5" /> <PencilSquareIcon className="w-5 h-5" />
New conversation New conversation
</div> </button>
{/* list of conversations */} {/* list of conversations */}
{groupedConv.map((group, i) => ( {groupedConv.map((group, i) => (
<div key={i}> <div key={i} role="group">
{/* group name (by date) */} {/* group name (by date) */}
{group.title ? ( {group.title ? (
// we use btn class here to make sure that the padding/margin are aligned with the other items // we use btn class here to make sure that the padding/margin are aligned with the other items
<b className="btn btn-ghost btn-xs bg-none btn-disabled block text-xs text-base-content text-start px-2 mb-0 mt-6 font-bold"> <b
className="btn btn-ghost btn-xs bg-none btn-disabled block text-xs text-base-content text-start px-2 mb-0 mt-6 font-bold"
role="note"
aria-description={group.title}
tabIndex={0}
>
{group.title} {group.title}
</b> </b>
) : ( ) : (
@ -184,20 +212,23 @@ function ConversationItem({
}) { }) {
return ( return (
<div <div
role="menuitem"
tabIndex={0}
aria-label={conv.name}
className={classNames({ className={classNames({
'group flex flex-row btn btn-ghost justify-start items-center font-normal px-2 h-9': 'group flex flex-row btn btn-ghost justify-start items-center font-normal px-2 h-9':
true, true,
'btn-soft': isCurrConv, 'btn-soft': isCurrConv,
})} })}
> >
<div <button
key={conv.id} key={conv.id}
className="w-full overflow-hidden truncate text-start" className="w-full overflow-hidden truncate text-start"
onClick={onSelect} onClick={onSelect}
dir="auto" dir="auto"
> >
{conv.name} {conv.name}
</div> </button>
<div className="dropdown dropdown-end h-5"> <div className="dropdown dropdown-end h-5">
<BtnWithTooltips <BtnWithTooltips
// on mobile, we always show the ellipsis icon // on mobile, we always show the ellipsis icon
@ -211,22 +242,23 @@ function ConversationItem({
</BtnWithTooltips> </BtnWithTooltips>
{/* dropdown menu */} {/* dropdown menu */}
<ul <ul
aria-label="More options"
tabIndex={0} tabIndex={0}
className="dropdown-content menu bg-base-100 rounded-box z-[1] p-2 shadow" className="dropdown-content menu bg-base-100 rounded-box z-[1] p-2 shadow"
> >
<li onClick={onRename}> <li onClick={onRename} tabIndex={0}>
<a> <a>
<PencilIcon className="w-4 h-4" /> <PencilIcon className="w-4 h-4" />
Rename Rename
</a> </a>
</li> </li>
<li onClick={onDownload}> <li onClick={onDownload} tabIndex={0}>
<a> <a>
<ArrowDownTrayIcon className="w-4 h-4" /> <ArrowDownTrayIcon className="w-4 h-4" />
Download Download
</a> </a>
</li> </li>
<li className="text-error" onClick={onDelete}> <li className="text-error" onClick={onDelete} tabIndex={0}>
<a> <a>
<TrashIcon className="w-4 h-4" /> <TrashIcon className="w-4 h-4" />
Delete Delete

View file

@ -34,9 +34,6 @@ html {
/* TODO: fix markdown table */ /* TODO: fix markdown table */
} }
.show-on-hover {
@apply md:opacity-0 md:group-hover:opacity-100;
}
.btn-mini { .btn-mini {
@apply cursor-pointer; @apply cursor-pointer;
} }

View file

@ -52,13 +52,20 @@ export function BtnWithTooltips({
tooltipsContent: string; tooltipsContent: string;
disabled?: boolean; disabled?: boolean;
}) { }) {
// the onClick handler is on the container, so screen readers can safely ignore the inner button
// this prevents the label from being read twice
return ( return (
<div className="tooltip tooltip-bottom" data-tip={tooltipsContent}> <div
className="tooltip tooltip-bottom"
data-tip={tooltipsContent}
role="button"
onClick={onClick}
>
<button <button
className={`${className ?? ''} flex items-center justify-center`} className={`${className ?? ''} flex items-center justify-center`}
onClick={onClick}
disabled={disabled} disabled={disabled}
onMouseLeave={onMouseLeave} onMouseLeave={onMouseLeave}
aria-hidden={true}
> >
{children} {children}
</button> </button>