mirror of https://github.com/LostRuins/koboldcpp.git (synced 2026-05-17 04:09:19 +00:00)
commit bc89b465a8: Merge branch 'upstream' into concedo_experimental

# Conflicts:
#	.github/workflows/build.yml
#	.github/workflows/release.yml
#	.github/workflows/server.yml
#	README.md
#	docs/build.md
#	docs/install.md
#	ggml/src/ggml-cpu/CMakeLists.txt
#	ggml/src/ggml-opencl/CMakeLists.txt
#	ggml/src/ggml-opencl/ggml-opencl.cpp
#	ggml/src/ggml-sycl/ggml-sycl.cpp
#	ggml/src/ggml-sycl/mmvq.cpp
#	ggml/src/ggml-sycl/vecdotq.hpp
#	tests/test-backend-ops.cpp
#	tests/test-chat.cpp

35 changed files with 1070 additions and 288 deletions
@@ -2871,6 +2871,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             "(default: deepseek)",
             [](common_params & params, const std::string & value) {
                 /**/ if (value == "deepseek") { params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; }
+                else if (value == "deepseek-legacy") { params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY; }
                 else if (value == "none") { params.reasoning_format = COMMON_REASONING_FORMAT_NONE; }
                 else { throw std::invalid_argument("invalid value"); }
             }
@@ -82,10 +82,10 @@ json common_chat_msg::to_json_oaicompat() const
 
 std::vector<common_chat_msg_diff> common_chat_msg_diff::compute_diffs(const common_chat_msg & previous_msg, const common_chat_msg & new_msg) {
     std::vector<common_chat_msg_diff> diffs;
-    // if (previous_msg.reasoning_content != current.reasoning_content) {
-    //     auto & diff = diffs.emplace_back();
-    //     diff.reasoning_content_delta = string_diff(previous_msg.reasoning_content, current.reasoning_content);
-    // }
+    if (previous_msg.reasoning_content != new_msg.reasoning_content) {
+        auto & diff = diffs.emplace_back();
+        diff.reasoning_content_delta = string_diff(previous_msg.reasoning_content, new_msg.reasoning_content);
+    }
     if (previous_msg.content != new_msg.content) {
         auto & diff = diffs.emplace_back();
         diff.content_delta = string_diff(previous_msg.content, new_msg.content);
@@ -385,9 +385,9 @@ json common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & t
 
 template <> json common_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff) {
     json delta = json::object();
-    // if (!diff.reasoning_content_delta.empty()) {
-    //     delta["reasoning_content"] = msg.reasoning_content;
-    // }
+    if (!diff.reasoning_content_delta.empty()) {
+        delta["reasoning_content"] = diff.reasoning_content_delta;
+    }
     if (!diff.content_delta.empty()) {
         delta["content"] = diff.content_delta;
     }
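Note (illustration, not part of the diff): the two hunks above turn on streamed reasoning deltas. compute_diffs() now emits a reasoning_content_delta whenever the reasoning text grows, and the JSON hunk forwards it as delta["reasoning_content"]. A minimal sketch of what string_diff() plausibly does here, assuming the new message extends the previous one (the usual streaming case); the function body below is an assumption, not taken from this diff:

    #include <string>

    // assumed sketch: the delta is the suffix appended since the previous message
    static std::string string_diff(const std::string & prev, const std::string & cur) {
        // the real implementation may also handle a prev that is not a prefix of cur
        return cur.substr(prev.size());
    }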
@@ -598,6 +598,7 @@ const char * common_reasoning_format_name(common_reasoning_format format) {
     switch (format) {
         case COMMON_REASONING_FORMAT_NONE:            return "none";
         case COMMON_REASONING_FORMAT_DEEPSEEK:        return "deepseek";
+        case COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY: return "deepseek-legacy";
         default:
             throw std::runtime_error("Unknown reasoning format");
     }
@@ -70,7 +70,7 @@ struct common_chat_msg {
 };
 
 struct common_chat_msg_diff {
-    // std::string reasoning_content_delta;
+    std::string reasoning_content_delta;
     std::string content_delta;
 
     size_t tool_call_index = std::string::npos;
     common_chat_tool_call tool_call_delta;
@@ -211,7 +211,8 @@ struct common_params_vocoder {
 
 enum common_reasoning_format {
     COMMON_REASONING_FORMAT_NONE,
-    COMMON_REASONING_FORMAT_DEEPSEEK, // Extract thinking tag contents and return as `message.reasoning_content`
+    COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY, // Extract thinking tag contents and return as `message.reasoning_content`, or leave inline in <think> tags in stream mode
+    COMMON_REASONING_FORMAT_DEEPSEEK,        // Extract thinking tag contents and return as `message.reasoning_content`, including in streaming deltas.
 };
 
 struct common_params {
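Note (illustration, not part of the diff): together with the arg-parsing hunk at the top, the enum now distinguishes the legacy behavior (reasoning left inline in <think> tags while streaming) from the new default (reasoning always extracted, including in streaming deltas). A self-contained sketch of the value-to-enum mapping the parser lambda implements; only the wrapper function name is invented here:

    #include <stdexcept>
    #include <string>

    enum common_reasoning_format {
        COMMON_REASONING_FORMAT_NONE,
        COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY,
        COMMON_REASONING_FORMAT_DEEPSEEK,
    };

    // hypothetical helper mirroring the CLI parsing above
    static common_reasoning_format parse_reasoning_format(const std::string & value) {
        if (value == "deepseek")        return COMMON_REASONING_FORMAT_DEEPSEEK;
        if (value == "deepseek-legacy") return COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY;
        if (value == "none")            return COMMON_REASONING_FORMAT_NONE;
        throw std::invalid_argument("invalid value");
    }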
@@ -8132,8 +8132,8 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
         #define WKV_VECTOR_SIZE 4
     #endif
 
-    int wkv_vector_size;
     #ifdef WKV_VECTOR_SIZE
+        int wkv_vector_size;
         #if defined(__ARM_FEATURE_SVE)
             wkv_vector_size = svcntw();
         #else
@@ -8348,8 +8348,8 @@ static void ggml_compute_forward_gla_f32(
         #define GLA_VECTOR_SIZE 4
     #endif
 
-    int gla_vector_size;
     #ifdef GLA_VECTOR_SIZE
+        int gla_vector_size;
         #if defined(__ARM_FEATURE_SVE)
             gla_vector_size = svcntw();
         #else
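Note (illustration, not part of the diff): both CPU hunks move the vector-size variable inside the #ifdef that guards its only uses, so configurations where WKV_VECTOR_SIZE / GLA_VECTOR_SIZE is never defined do not carry an unused local. The pattern, sketched with the scalar fallback branch filled in as an assumption (the diff truncates before the #else body):

    #ifdef WKV_VECTOR_SIZE
        int wkv_vector_size;
        #if defined(__ARM_FEATURE_SVE)
            wkv_vector_size = svcntw();        // SVE: 32-bit lanes per vector
        #else
            wkv_vector_size = WKV_VECTOR_SIZE; // assumed fixed-width fallback
        #endif
    #endif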
@@ -652,9 +652,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
     float KQ_max_scale[cols_per_thread];
 #pragma unroll
     for (int col = 0; col < cols_per_thread; ++col) {
-        KQ_max_scale[col] = expf(KQ_max[col] - KQ_max_new[col]);
+        const float KQ_max_diff = KQ_max[col] - KQ_max_new[col];
+        KQ_max_scale[col] = expf(KQ_max_diff);
         KQ_max[col] = KQ_max_new[col];
 
+        *((uint32_t *) &KQ_max_scale[col]) *= KQ_max_diff >= SOFTMAX_FTZ_THRESHOLD;
+
         // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
         KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_rowsum_add[col];
     }
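Note (illustration, not part of the diff): the added bit-punning line flushes the scale factor to exactly +0.0f when KQ_max_diff falls below SOFTMAX_FTZ_THRESHOLD: multiplying the raw float bits by a 0/1 comparison result either preserves the bit pattern or zeroes it. A host-side C++ sketch of the same trick (memcpy used here for well-defined punning; the CUDA code casts through a pointer instead):

    #include <cstdint>
    #include <cstring>

    // returns v unchanged when keep is true, +0.0f otherwise
    static float flush_to_zero_unless(float v, bool keep) {
        uint32_t bits;
        std::memcpy(&bits, &v, sizeof(bits));
        bits *= keep; // keep ? bits : 0u, i.e. the bit pattern of +0.0f
        std::memcpy(&v, &bits, sizeof(v));
        return v;
    }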
@@ -4766,6 +4766,8 @@ static bool ggml_metal_encode_node(
     GGML_ASSERT(nqptg % 8  == 0);
     GGML_ASSERT(ncpsg % 32 == 0);
 
+    const int is_q = ggml_is_quantized(src1->type) ? 1 : 0;
+
     // 2*(2*ncpsg + nqptg)*(nsg)
     // ncpsg soft_max values + ncpsg mask values + a diagonal scaling matrix (in float)
     //
@@ -4773,7 +4775,7 @@ static bool ggml_metal_encode_node(
     // the shared memory needed for the simdgroups to load the KV cache
     // each thread loads (dequantizes) 16 head elements, there are 32 threads in th SG
     //
-#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*(2*ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16))
+#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(2*ne00 + 2*(2*ncpsg + nqptg)*(nsg)) + is_q*(16*32*(nsg)))*(sizeof(float)/2), 16))
 
     int64_t nsgmax = 2;
 
@@ -4810,9 +4812,9 @@ static bool ggml_metal_encode_node(
     // and store the soft_max values and the mask
     //
     // ne00*(nsg)
-    // each simdgroup has a full f16 head vector in shared mem to accumulate results
+    // each simdgroup has a full f32 head vector in shared mem to accumulate results
     //
-#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + ne20*(nsg))*(sizeof(float)/2), 16))
+#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + 2*ne20*(nsg))*(sizeof(float)/2), 16))
 
     int64_t nsgmax = 2;
     while (true) {
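Note (illustration, not part of the diff): the shared-memory formulas change in two ways: the output accumulator is budgeted at 2*ne00 per query (f32 instead of f16, twice the footprint in half-sized units), and the 16*32*(nsg) K-load scratch is only charged when the KV cache is quantized (is_q). A sketch that just evaluates the new macro for some illustrative sizes; all numbers below are made up, not from the diff:

    #include <cstdio>

    constexpr long pad(long x, long n) { return (x + n - 1) / n * n; } // GGML_PAD

    int main() {
        const long nqptg = 8, ncpsg = 32, ne00 = 128, nsg = 2;
        for (int is_q = 0; is_q <= 1; ++is_q) {
            // FATTN_SMEM, new version; the *2 is sizeof(float)/2 bytes per half element
            const long smem = pad((nqptg*(2*ne00 + 2*(2*ncpsg + nqptg)*nsg) + is_q*(16*32*nsg))*2, 16);
            std::printf("is_q=%d -> %ld bytes of threadgroup memory\n", is_q, smem);
        }
    }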
@@ -3328,14 +3328,14 @@ kernel void kernel_flash_attn_ext(
     constexpr short NW = N_SIMDWIDTH;
     constexpr short SH = (2*C + Q); // shared memory per simdgroup (s_t == float)
 
-    const short TS = nsg*SH;    // shared memory size per query in (s_t == float)
-    const short T  = DK + 2*TS; // shared memory size per query in (half)
+    const short TS = nsg*SH;      // shared memory size per query in (s_t == float)
+    const short T  = 2*DK + 2*TS; // shared memory size per query in (half)
 
-    threadgroup q_t  * sq  = (threadgroup q_t  *) (shmem_f16 +              0*DK); // holds the query data
-    threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 +              0*DK); // same as above but in q4_t
-    threadgroup o_t  * so  = (threadgroup o_t  *) (shmem_f16 +              0*DK); // reuse query data for accumulation
-    threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 +              0*DK); // same as above but in o4_t
-    threadgroup s_t  * ss  = (threadgroup s_t  *) (shmem_f16 + 2*sgitg*SH +   Q*DK); // scratch buffer for attention, mask and diagonal matrix
+    threadgroup q_t  * sq  = (threadgroup q_t  *) (shmem_f16 +              0*DK); // holds the query data
+    threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 +              0*DK); // same as above but in q4_t
+    threadgroup o_t  * so  = (threadgroup o_t  *) (shmem_f16 +              0*DK); // reuse query data for accumulation
+    threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 +              0*DK); // same as above but in o4_t
+    threadgroup s_t  * ss  = (threadgroup s_t  *) (shmem_f16 + 2*sgitg*SH + 2*Q*DK); // scratch buffer for attention, mask and diagonal matrix
 
     threadgroup k_t    * sk    = (threadgroup k_t    *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory
     threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t
@@ -3354,7 +3354,7 @@ kernel void kernel_flash_attn_ext(
             if (iq1 + j < args.ne01) {
                 sq4[j*DK4 + i] = (q4_t) q4[i];
             } else {
-                sq4[j*DK4 + i] = (q4_t) 0.0f;
+                sq4[j*DK4 + i] = 0;
             }
         }
     }
@@ -3634,9 +3634,6 @@ kernel void kernel_flash_attn_ext(
 
     // reduce the warps sequentially
     for (ushort sg = 1; sg < nsg; ++sg) {
-        float S = { 0.0f };
-        float M = { -__FLT_MAX__/2 };
-
         threadgroup_barrier(mem_flags::mem_threadgroup);
 
         // each simdgroup stores its output to shared memory, reusing sq
@@ -3657,12 +3654,12 @@ kernel void kernel_flash_attn_ext(
             const float M0 = ss[j*TS +         1];
             const float M1 = ss[j*TS + sg*SH + 1];
 
-            M = max(M0, M1);
+            const float M = max(M0, M1);
 
             const float ms0 = exp(M0 - M);
             const float ms1 = exp(M1 - M);
 
-            S = S0*ms0 + S1*ms1;
+            const float S = S0*ms0 + S1*ms1;
 
             if (tiisg == 0) {
                 ss[j*TS + 0] = S;
@@ -3701,16 +3698,18 @@ kernel void kernel_flash_attn_ext(
         }
     }
 
-    device float4 * dst4 = (device float4 *) dst;
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    threadgroup s_t * sf = (threadgroup s_t *) (shmem_f16 + 2*Q*DK);
 
     // final rescale with 1/S and store to global memory
-    if (sgitg == 0) {
-        for (short j = 0; j < Q && iq1 + j < args.ne01; ++j) {
-            const float S = ss[j*TS + 0];
+    for (short j = sgitg; j < Q && iq1 + j < args.ne01; j += nsg) {
+        const float S = 1.0f/sf[j*TS + 0];
 
-            for (short i = tiisg; i < DV4; i += NW) {
-                dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*DV4 + i] = (float4) so4[j*DV4 + i]/S;
-            }
+        device float4 * dst4 = (device float4 *) dst + ((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*DV4;
+
+        for (short i = tiisg; i < DV4; i += NW) {
+            dst4[i] = (float4) so4[j*DV4 + i]*S;
+        }
     }
-    }
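Note (illustration, not part of the diff): the reduction hunks above implement the standard two-way online-softmax merge: each pair of partial (sum, max) accumulators is combined by rescaling both sums to the larger max before adding. A scalar C++ reference of the rule; this is a sketch, not the kernel:

    #include <algorithm>
    #include <cmath>

    // merge partial softmax accumulators (S1, M1) into (S0, M0)
    static void softmax_merge(float & S0, float & M0, float S1, float M1) {
        const float M = std::max(M0, M1);
        S0 = S0*std::exp(M0 - M) + S1*std::exp(M1 - M);
        M0 = M;
    }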
@@ -3719,12 +3718,22 @@ kernel void kernel_flash_attn_ext(
 // template to be able to explore different combinations
 //
 #define FA_TYPES \
-    half,  half4,   simdgroup_half8x8,  \
+    float, float4,  simdgroup_float8x8, \
     half,  half4x4, simdgroup_half8x8,  \
     half,  half4x4, simdgroup_half8x8,  \
     float,          simdgroup_float8x8, \
     float,          simdgroup_float8x8, \
-    half,  half4,   simdgroup_half8x8
+    float, float4,  simdgroup_float8x8
+    //half, half4, simdgroup_half8x8
+
+#define FA_TYPES_BF \
+    bfloat, bfloat4,   simdgroup_bfloat8x8, \
+    bfloat, bfloat4x4, simdgroup_bfloat8x8, \
+    bfloat, bfloat4x4, simdgroup_bfloat8x8, \
+    float,             simdgroup_float8x8,  \
+    float,             simdgroup_float8x8,  \
+    float,  float4,    simdgroup_float8x8
+    //half, half4, simdgroup_half8x8
 
 typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>) flash_attn_ext_t;
 
@@ -3739,15 +3748,15 @@ template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_at
 template [[host_name("kernel_flash_attn_ext_f16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 576, 512>;
 
 #if defined(GGML_METAL_USE_BF16)
-template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 64, 64>;
-template [[host_name("kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 80, 80>;
-template [[host_name("kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 96, 96>;
-template [[host_name("kernel_flash_attn_ext_bf16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 112, 112>;
-template [[host_name("kernel_flash_attn_ext_bf16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 128, 128>;
-template [[host_name("kernel_flash_attn_ext_bf16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 192>;
-template [[host_name("kernel_flash_attn_ext_bf16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 128>;
-template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 256, 256>;
-template [[host_name("kernel_flash_attn_ext_bf16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 576, 512>;
+template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 64, 64>;
+template [[host_name("kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 80, 80>;
+template [[host_name("kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 96, 96>;
+template [[host_name("kernel_flash_attn_ext_bf16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 112, 112>;
+template [[host_name("kernel_flash_attn_ext_bf16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 128, 128>;
+template [[host_name("kernel_flash_attn_ext_bf16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 192>;
+template [[host_name("kernel_flash_attn_ext_bf16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 128>;
+template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 256, 256>;
+template [[host_name("kernel_flash_attn_ext_bf16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 576, 512>;
 #endif
 
 template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 64, 64>;
@@ -3801,6 +3810,7 @@ template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_at
 template [[host_name("kernel_flash_attn_ext_q8_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 576, 512>;
 
 #undef FA_TYPES
+#undef FA_TYPES_BF
 
 template<
     typename q4_t, // query types in shared memory
@@ -3847,12 +3857,12 @@ kernel void kernel_flash_attn_ext_vec(
 
     const short T = DK + nsg*SH; // shared memory size per query in (half)
 
-  //threadgroup q_t   * sq  = (threadgroup q_t   *) (shmem_f16 +            0*DK); // holds the query data
-    threadgroup q4_t  * sq4 = (threadgroup q4_t  *) (shmem_f16 +            0*DK); // same as above but in q4_t
-    threadgroup s_t   * ss  = (threadgroup s_t   *) (shmem_f16 + sgitg*SH + Q*DK); // scratch buffer for attention
-    threadgroup s4_t  * ss4 = (threadgroup s4_t  *) (shmem_f16 + sgitg*SH + Q*DK); // same as above but in s4_t
-    threadgroup float * sm  = (threadgroup float *) (shmem_f16 + sgitg*SH + 2*C + Q*DK); // scratch buffer for mask
-    threadgroup o4_t  * sr4 = (threadgroup o4_t  *) (shmem_f16 + sgitg*DV  + Q*T); // scratch buffer for the results
+  //threadgroup q_t   * sq  = (threadgroup q_t   *) (shmem_f16 +              0*DK); // holds the query data
+    threadgroup q4_t  * sq4 = (threadgroup q4_t  *) (shmem_f16 +              0*DK); // same as above but in q4_t
+    threadgroup s_t   * ss  = (threadgroup s_t   *) (shmem_f16 + sgitg*SH   + Q*DK); // scratch buffer for attention
+    threadgroup s4_t  * ss4 = (threadgroup s4_t  *) (shmem_f16 + sgitg*SH   + Q*DK); // same as above but in s4_t
+    threadgroup float * sm  = (threadgroup float *) (shmem_f16 + sgitg*SH   + 2*C + Q*DK); // scratch buffer for mask
+    threadgroup o4_t  * sr4 = (threadgroup o4_t  *) (shmem_f16 + 2*sgitg*DV + Q*T); // scratch buffer for the results
 
     // store the result for all queries in local memory (the O matrix from the paper)
     o4_t lo[DV4/NL];
@@ -4157,7 +4167,7 @@ kernel void kernel_flash_attn_ext_vec(
     half4, \
     float, \
     float, float4, \
-    half4
+    float4
 
 typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>) flash_attn_ext_vec_t;
 
ggml/src/ggml-opencl/kernels/concat.cl (new file, 109 lines)
@@ -0,0 +1,109 @@
kernel void kernel_concat_f32_contiguous(
    global const char * p_src0, ulong off_src0,
    global const char * p_src1, ulong off_src1,
    global char * p_dst, ulong off_dst,
    int d_ne00, int d_ne01, int d_ne02, // src0->ne[0..2] for the slice
    int d_ne10, int d_ne11, int d_ne12, // src1->ne[0..2] for the slice (d_ne1X must match d_ne0X on non-concat axes)
    int d_ne0, int d_ne1, int d_ne2,    // dst->ne[0..2] for the slice
    int dim
) {
    global const float * src0 = (global const float*)((global char*)p_src0 + off_src0);
    global const float * src1 = (global const float*)((global char*)p_src1 + off_src1);
    global float * dst = (global float*)((global char*)p_dst + off_dst);

    int i0 = get_global_id(0); // Index along dst's 0th dimension
    int i1 = get_global_id(1); // Index along dst's 1st dimension
    int i2 = get_global_id(2); // Index along dst's 2nd dimension

    if (i0 >= d_ne0 || i1 >= d_ne1 || i2 >= d_ne2) {
        return;
    }

    ulong dst_idx = (ulong)i2 * d_ne0 * d_ne1 + (ulong)i1 * d_ne0 + i0;
    ulong src_idx;

    if (dim == 0) {
        if (i0 < d_ne00) { // Data from src0
            src_idx = (ulong)i2 * d_ne00 * d_ne01 + (ulong)i1 * d_ne00 + i0;
            dst[dst_idx] = src0[src_idx];
        } else { // Data from src1
            src_idx = (ulong)i2 * d_ne10 * d_ne11 + (ulong)i1 * d_ne10 + (i0 - d_ne00);
            dst[dst_idx] = src1[src_idx];
        }
    } else if (dim == 1) {
        if (i1 < d_ne01) { // Data from src0
            src_idx = (ulong)i2 * d_ne00 * d_ne01 + (ulong)i1 * d_ne00 + i0;
            dst[dst_idx] = src0[src_idx];
        } else { // Data from src1
            src_idx = (ulong)i2 * d_ne10 * d_ne11 + (ulong)(i1 - d_ne01) * d_ne10 + i0;
            dst[dst_idx] = src1[src_idx];
        }
    } else if (dim == 2) {
        if (i2 < d_ne02) { // Data from src0
            src_idx = (ulong)i2 * d_ne00 * d_ne01 + (ulong)i1 * d_ne00 + i0;
            dst[dst_idx] = src0[src_idx];
        } else { // Data from src1
            src_idx = (ulong)(i2 - d_ne02) * d_ne10 * d_ne11 + (ulong)i1 * d_ne10 + i0;
            dst[dst_idx] = src1[src_idx];
        }
    }
}

kernel void kernel_concat_f32_non_contiguous(
    global const char * p_src0, ulong off_src0,
    global const char * p_src1, ulong off_src1,
    global char * p_dst, ulong off_dst,

    long ne00, long ne01, long ne02, long ne03,
    ulong nb00, ulong nb01, ulong nb02, ulong nb03,

    ulong nb10, ulong nb11, ulong nb12, ulong nb13, // Strides for src1

    long d_ne0, long d_ne1, long d_ne2, long d_ne3,
    ulong d_nb0, ulong d_nb1, ulong d_nb2, ulong d_nb3,
    int dim
) {
    global const char * src0_base = p_src0 + off_src0;
    global const char * src1_base = p_src1 + off_src1;
    global char * dst_base = p_dst + off_dst;

    long current_i1 = get_global_id(0); // Index for dst_dim_1
    long current_i2 = get_global_id(1); // Index for dst_dim_2
    long current_i3 = get_global_id(2); // Index for dst_dim_3

    if (current_i1 >= d_ne1 || current_i2 >= d_ne2 || current_i3 >= d_ne3) {
        return;
    }

    global const float * x_val_ptr;
    global float * y_val_ptr;

    for (long current_i0 = 0; current_i0 < d_ne0; ++current_i0) {
        bool use_src0;
        long s_i0 = current_i0, s_i1 = current_i1, s_i2 = current_i2, s_i3 = current_i3;

        if (dim == 0) {
            use_src0 = (current_i0 < ne00);
            if (!use_src0) { s_i0 = current_i0 - ne00; }
        } else if (dim == 1) {
            use_src0 = (current_i1 < ne01);
            if (!use_src0) { s_i1 = current_i1 - ne01; }
        } else if (dim == 2) {
            use_src0 = (current_i2 < ne02);
            if (!use_src0) { s_i2 = current_i2 - ne02; }
        } else { // dim == 3
            use_src0 = (current_i3 < ne03);
            if (!use_src0) { s_i3 = current_i3 - ne03; }
        }

        if (use_src0) {
            x_val_ptr = (global const float *)(src0_base + (ulong)s_i3*nb03 + (ulong)s_i2*nb02 + (ulong)s_i1*nb01 + (ulong)s_i0*nb00);
        } else {
            x_val_ptr = (global const float *)(src1_base + (ulong)s_i3*nb13 + (ulong)s_i2*nb12 + (ulong)s_i1*nb11 + (ulong)s_i0*nb10);
        }

        y_val_ptr = (global float *)(dst_base + (ulong)current_i3*d_nb3 + (ulong)current_i2*d_nb2 + (ulong)current_i1*d_nb1 + (ulong)current_i0*d_nb0);
        *y_val_ptr = *x_val_ptr;
    }
}
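Note (illustration, not part of the diff): both concat kernels follow the same rule: a destination coordinate reads from src0 while it is inside src0's extent along the concat dimension, otherwise from src1 with that coordinate shifted back by src0's extent. A scalar C++ reference of the contiguous dim == 1 case:

    #include <cstddef>

    // sketch: value of dst[i2][i1][i0] for concat along dim 1
    static float concat_dim1(const float * src0, const float * src1,
                             int i0, int i1, int i2,
                             int ne00, int ne01, int ne10, int ne11) {
        if (i1 < ne01) {
            return src0[(std::size_t)i2*ne00*ne01 + (std::size_t)i1*ne00 + i0];
        }
        return src1[(std::size_t)i2*ne10*ne11 + (std::size_t)(i1 - ne01)*ne10 + i0];
    }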
ggml/src/ggml-opencl/kernels/pad.cl (new file, 30 lines)
@@ -0,0 +1,30 @@
kernel void kernel_pad(
    global const void * src0_ptr,
    ulong src0_offset,
    global void * dst_ptr,
    ulong dst_offset,
    int s_ne0, int s_ne1, int s_ne2,
    int d_ne0, int d_ne1, int d_ne2
) {
    global const float * src0 = (global const float *)((global const char *)src0_ptr + src0_offset);
    global float * dst = (global float *)((global char *)dst_ptr + dst_offset);

    int nidx = get_global_id(0);
    int idx_d1 = get_group_id(1);
    int idx_d2 = get_group_id(2);

    if (nidx >= d_ne0) {
        return;
    }

    int dst_el_offset = nidx + idx_d1 * d_ne0 + idx_d2 * d_ne0 * d_ne1;

    bool in_src_bounds = (nidx < s_ne0) && (idx_d1 < s_ne1) && (idx_d2 < s_ne2);

    if (in_src_bounds) {
        int src_el_offset = nidx + idx_d1 * s_ne0 + idx_d2 * s_ne0 * s_ne1;
        dst[dst_el_offset] = src0[src_el_offset];
    } else {
        dst[dst_el_offset] = 0.0f;
    }
}
ggml/src/ggml-opencl/kernels/repeat.cl (new file, 39 lines)
@@ -0,0 +1,39 @@
kernel void kernel_repeat(
    global const char * src0_data_in,
    global char * dst_data_in,
    ulong src0_offset,
    ulong dst_offset,
    int src0_ne0, int src0_ne1, int src0_ne2, int src0_ne3,
    ulong src0_nb0, ulong src0_nb1, ulong src0_nb2, ulong src0_nb3,
    int dst_ne0, int dst_ne1, int dst_ne2, int dst_ne3,
    ulong dst_nb0, ulong dst_nb1, ulong dst_nb2, ulong dst_nb3
) {
    global const char * src0_data = src0_data_in + src0_offset;
    global char * dst_data = dst_data_in + dst_offset;

    const int d3 = get_global_id(2);
    const int d2 = get_global_id(1);
    const int d1 = get_global_id(0);

    if (d3 >= dst_ne3 || d2 >= dst_ne2 || d1 >= dst_ne1) {
        return;
    }

    const int s3 = d3 % src0_ne3;
    const int s2 = d2 % src0_ne2;
    const int s1 = d1 % src0_ne1;

    const global char * p_src0_slice = src0_data + (ulong)s3*src0_nb3 + (ulong)s2*src0_nb2 + (ulong)s1*src0_nb1;
    global char * p_dst_slice = dst_data + (ulong)d3*dst_nb3 + (ulong)d2*dst_nb2 + (ulong)d1*dst_nb1;

    for (int d0 = 0; d0 < dst_ne0; ++d0) {
        // Determine source index for dimension 0 based on tiling/broadcasting.
        const int s0 = d0 % src0_ne0;

        const global char * restrict current_src_el_ptr = p_src0_slice + (ulong)s0*src0_nb0;
        global char * restrict current_dst_el_ptr = p_dst_slice + (ulong)d0*dst_nb0;
        for (int k = 0; k < src0_nb0; ++k) {
            current_dst_el_ptr[k] = current_src_el_ptr[k];
        }
    }
}
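Note (illustration, not part of the diff): kernel_repeat implements ggml-style broadcast tiling: every destination coordinate maps back to a source coordinate by taking it modulo the source extent on each dimension, and src0_nb0 bytes are copied per element so any element type works. The coordinate mapping, as a C++ sketch:

    #include <array>

    // sketch: the src0 coordinate that dst coordinate d is copied from
    static std::array<int, 4> repeat_src_coord(std::array<int, 4> d, std::array<int, 4> src_ne) {
        return { d[0] % src_ne[0], d[1] % src_ne[1], d[2] % src_ne[2], d[3] % src_ne[3] };
    }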
ggml/src/ggml-opencl/kernels/tanh.cl (new file, 63 lines)
@@ -0,0 +1,63 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable

#ifdef cl_intel_required_subgroup_size
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
#define INTEL_GPU 1
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
#elif defined(cl_qcom_reqd_sub_group_size)
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#endif

kernel void kernel_tanh_f32_nd(
    global void * p_src0_base, ulong off_src0_abs,
    global void * p_dst_base, ulong off_dst_abs,
    int ne00, int ne01, int ne02, int ne03,
    ulong nb00, ulong nb01, ulong nb02, ulong nb03,
    int ne10, int ne11, int ne12, int ne13,
    ulong nb10, ulong nb11, ulong nb12, ulong nb13
) {
    int i0 = get_global_id(0);
    int i1 = get_global_id(1);
    int i2 = get_global_id(2);

    if (i0 < ne10 && i1 < ne11 && i2 < ne12) {
        for (int i3 = 0; i3 < ne13; ++i3) {
            ulong src_offset_in_tensor = (ulong)i0*nb00 + (ulong)i1*nb01 + (ulong)i2*nb02 + (ulong)i3*nb03;
            global const float *src_val_ptr = (global const float *)((global char *)p_src0_base + off_src0_abs + src_offset_in_tensor);

            ulong dst_offset_in_tensor = (ulong)i0*nb10 + (ulong)i1*nb11 + (ulong)i2*nb12 + (ulong)i3*nb13;
            global float *dst_val_ptr = (global float *)((global char *)p_dst_base + off_dst_abs + dst_offset_in_tensor);

            *dst_val_ptr = tanh(*src_val_ptr);
        }
    }
}

kernel void kernel_tanh_f16_nd(
    global void * p_src0_base, ulong off_src0_abs,
    global void * p_dst_base, ulong off_dst_abs,
    int ne00, int ne01, int ne02, int ne03,
    ulong nb00, ulong nb01, ulong nb02, ulong nb03,
    int ne10, int ne11, int ne12, int ne13,
    ulong nb10, ulong nb11, ulong nb12, ulong nb13
) {
    int i0 = get_global_id(0);
    int i1 = get_global_id(1);
    int i2 = get_global_id(2);

    if (i0 < ne10 && i1 < ne11 && i2 < ne12) {
        for (int i3 = 0; i3 < ne13; ++i3) {
            ulong src_offset_in_tensor = (ulong)i0*nb00 + (ulong)i1*nb01 + (ulong)i2*nb02 + (ulong)i3*nb03;
            global const half *src_val_ptr = (global const half *)((global char *)p_src0_base + off_src0_abs + src_offset_in_tensor);

            ulong dst_offset_in_tensor = (ulong)i0*nb10 + (ulong)i1*nb11 + (ulong)i2*nb12 + (ulong)i3*nb13;
            global half *dst_val_ptr = (global half *)((global char *)p_dst_base + off_dst_abs + dst_offset_in_tensor);

            *dst_val_ptr = tanh(*src_val_ptr);
        }
    }
}
ggml/src/ggml-opencl/kernels/tsembd.cl (new file, 48 lines)
@@ -0,0 +1,48 @@
kernel void kernel_timestep_embedding(
    global const void * p_timesteps,
    ulong off_timesteps,
    global void * p_dst,
    ulong off_dst,
    int dst_nb1_bytes,
    int logical_dim,
    int max_period
) {
    int local_i;
    int local_j;
    int local_half_dim;
    float local_timestep_val;
    float local_freq;
    float local_arg;
    global float * local_embed_data_ptr;
    global const float * local_timesteps_input_ptr;
    global float * local_dst_output_base_ptr;

    local_timesteps_input_ptr = (global const float *)((global char *)p_timesteps + off_timesteps);
    local_dst_output_base_ptr = (global float *)((global char *)p_dst + off_dst);

    local_i = get_global_id(1);
    local_j = get_global_id(0);

    local_half_dim = logical_dim / 2;
    local_embed_data_ptr = (global float *)((global char *)local_dst_output_base_ptr + local_i * dst_nb1_bytes);

    if (logical_dim % 2 != 0 && local_j == ((logical_dim + 1) / 2)) {
        local_embed_data_ptr[logical_dim] = 0.0f;
    }

    if (local_j >= local_half_dim) {
        return;
    }

    local_timestep_val = local_timesteps_input_ptr[local_i];

    if (local_half_dim == 0) {
        local_freq = 1.0f;
    } else {
        local_freq = exp(-log((float)max_period) * (float)local_j / (float)local_half_dim);
    }

    local_arg = local_timestep_val * local_freq;
    local_embed_data_ptr[local_j] = cos(local_arg);
    local_embed_data_ptr[local_j + local_half_dim] = sin(local_arg);
}
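Note (illustration, not part of the diff): the kernel computes the usual diffusion-style sinusoidal timestep embedding: freq_j = exp(-ln(max_period) * j / half_dim), out[j] = cos(t * freq_j), out[j + half_dim] = sin(t * freq_j), with one trailing zero slot when the logical dimension is odd. A CPU reference sketch for the even-dim case:

    #include <cmath>
    #include <vector>

    static std::vector<float> timestep_embedding(float t, int dim, int max_period) {
        const int half = dim / 2; // assumes even dim; the kernel zero-pads one slot for odd dim
        std::vector<float> out(dim, 0.0f);
        for (int j = 0; j < half; ++j) {
            const float freq = std::exp(-std::log((float)max_period) * (float)j / (float)half);
            out[j]        = std::cos(t * freq);
            out[j + half] = std::sin(t * freq);
        }
        return out;
    }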
ggml/src/ggml-opencl/kernels/upscale.cl (new file, 121 lines)
@@ -0,0 +1,121 @@
kernel void kernel_upscale(
    global const void * p_src0,
    ulong off_src0,
    global void * p_dst,
    ulong off_dst,
    ulong nb00,
    ulong nb01,
    ulong nb02,
    ulong nb03,
    int ne10,
    int ne11,
    int ne12,
    int ne13,
    float sf0,
    float sf1,
    float sf2,
    float sf3
) {
    global const char * src_base = (global const char *)p_src0 + off_src0;
    global float * dst_base = (global float *)((global char *)p_dst + off_dst);

    int index = get_global_id(0);
    int dst_total_elements = ne10 * ne11 * ne12 * ne13;

    if (index >= dst_total_elements) {
        return;
    }

    int i10 = index % ne10;
    int i11 = (index / ne10) % ne11;
    int i12 = (index / (ne10 * ne11)) % ne12;
    int i13 = index / (ne10 * ne11 * ne12);

    int i00 = (int)(i10 / sf0);
    int i01 = (int)(i11 / sf1);
    int i02 = (int)(i12 / sf2);
    int i03 = (int)(i13 / sf3);

    ulong offset_src_element = (ulong)i03 * nb03 + (ulong)i02 * nb02 + (ulong)i01 * nb01 + (ulong)i00 * nb00;
    global const float * src_element_ptr = (global const float *)(src_base + offset_src_element);

    dst_base[index] = *src_element_ptr;
}

kernel void kernel_upscale_bilinear(
    global const void * p_src0,
    ulong off_src0,
    global void * p_dst,
    ulong off_dst,
    ulong nb00,
    ulong nb01,
    ulong nb02,
    ulong nb03,
    int ne00_src,
    int ne01_src,
    int ne10_dst,
    int ne11_dst,
    int ne12_dst,
    int ne13_dst,
    float sf0,
    float sf1,
    float sf2,
    float sf3
) {
    global const char * src_base = (global const char *)p_src0 + off_src0;
    global float * dst_base = (global float *)((global char *)p_dst + off_dst);

    int index = get_global_id(0);
    int dst_total_elements = ne10_dst * ne11_dst * ne12_dst * ne13_dst;

    if (index >= dst_total_elements) {
        return;
    }

    int i10_dst = index % ne10_dst;
    int i11_dst = (index / ne10_dst) % ne11_dst;
    int i12_dst = (index / (ne10_dst * ne11_dst)) % ne12_dst;
    int i13_dst = index / (ne10_dst * ne11_dst * ne12_dst);

    int i02_src = (int)(i12_dst / sf2);
    int i03_src = (int)(i13_dst / sf3);

    const float pixel_offset = 0.5f;

    float y_src_f = ((float)i11_dst + pixel_offset) / sf1 - pixel_offset;
    long y0_src = (long)floor(y_src_f);
    long y1_src = y0_src + 1;

    y0_src = max(0L, min(y0_src, (long)ne01_src - 1));
    y1_src = max(0L, min(y1_src, (long)ne01_src - 1));

    float dy = y_src_f - (float)y0_src;
    dy = max(0.0f, min(dy, 1.0f));

    float x_src_f = ((float)i10_dst + pixel_offset) / sf0 - pixel_offset;
    long x0_src = (long)floor(x_src_f);
    long x1_src = x0_src + 1;

    x0_src = max(0L, min(x0_src, (long)ne00_src - 1));
    x1_src = max(0L, min(x1_src, (long)ne00_src - 1));

    float dx = x_src_f - (float)x0_src;
    dx = max(0.0f, min(dx, 1.0f));

    global const float * p_a = (global const float *)(src_base + (ulong)x0_src * nb00 + (ulong)y0_src * nb01 + (ulong)i02_src * nb02 + (ulong)i03_src * nb03);
    global const float * p_b = (global const float *)(src_base + (ulong)x1_src * nb00 + (ulong)y0_src * nb01 + (ulong)i02_src * nb02 + (ulong)i03_src * nb03);
    global const float * p_c = (global const float *)(src_base + (ulong)x0_src * nb00 + (ulong)y1_src * nb01 + (ulong)i02_src * nb02 + (ulong)i03_src * nb03);
    global const float * p_d = (global const float *)(src_base + (ulong)x1_src * nb00 + (ulong)y1_src * nb01 + (ulong)i02_src * nb02 + (ulong)i03_src * nb03);

    const float val_a = *p_a;
    const float val_b = *p_b;
    const float val_c = *p_c;
    const float val_d = *p_d;

    float result = val_a * (1.0f - dx) * (1.0f - dy) +
                   val_b * dx * (1.0f - dy) +
                   val_c * (1.0f - dx) * dy +
                   val_d * dx * dy;

    dst_base[index] = result;
}
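Note (illustration, not part of the diff): kernel_upscale_bilinear samples with half-pixel centers (the 0.5f pixel_offset), clamps the two neighboring source rows/columns to the image border, and blends the four taps with standard bilinear weights. The final blend in isolation, as a C++ sketch:

    // sketch: standard bilinear blend of the four gathered taps
    static float bilerp(float a, float b, float c, float d, float dx, float dy) {
        return a*(1.0f - dx)*(1.0f - dy)
             + b*dx*(1.0f - dy)
             + c*(1.0f - dx)*dy
             + d*dx*dy;
    }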
@@ -412,6 +412,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_count_equal_i32;
     vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
     vk_pipeline pipeline_timestep_embedding_f32;
+    vk_pipeline pipeline_conv_transpose_1d_f32;
     vk_pipeline pipeline_pool2d_f32;
     vk_pipeline pipeline_rwkv_wkv6_f32;
     vk_pipeline pipeline_rwkv_wkv7_f32;
@@ -460,7 +461,7 @@ struct vk_device_struct {
     // for GGML_VK_PERF_LOGGER
     std::unique_ptr<vk_perf_logger> perf_logger;
     vk::QueryPool query_pool;
-    uint32_t num_queries;
+    int32_t num_queries;
 
     ~vk_device_struct() {
         VK_LOG_DEBUG("destroy device " << name);
@@ -722,6 +723,21 @@ struct vk_op_timestep_embedding_push_constants {
     uint32_t max_period;
 };
 
+struct vk_op_conv_transpose_1d_push_constants {
+    uint32_t Cout;
+    uint32_t Cin;
+    uint32_t K;
+    uint32_t L;
+    uint32_t KL;
+
+    uint32_t nb01;
+    uint32_t nb02;
+    uint32_t nb11;
+    uint32_t nb1;
+
+    int32_t s0;
+};
+
 struct vk_op_pool2d_push_constants {
     uint32_t IW; uint32_t IH;
     uint32_t OW; uint32_t OH;
@@ -2742,6 +2758,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1);
 
+    ggml_vk_create_pipeline(device, device->pipeline_conv_transpose_1d_f32, "conv_transpose_1d_f32", conv_transpose_1d_f32_len, conv_transpose_1d_f32_data, "main", 3, sizeof(vk_op_conv_transpose_1d_push_constants), {1, 1, 1}, {}, 1);
+
     ggml_vk_create_pipeline(device, device->pipeline_pool2d_f32, "pool2d_f32", pool2d_f32_len, pool2d_f32_data, "main", 2, sizeof(vk_op_pool2d_push_constants), {512, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
@@ -6416,6 +6434,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_timestep_embedding_f32;
         }
         return nullptr;
+    case GGML_OP_CONV_TRANSPOSE_1D:
+        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_conv_transpose_1d_f32;
+        }
+        return nullptr;
     case GGML_OP_POOL_2D:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
             return ctx->device->pipeline_pool2d_f32;
@@ -6750,6 +6773,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
             uint32_t half_ceil = (dim + 1) / 2;
             elements = { half_ceil, (uint32_t)src0->ne[0], 1 };
         } break;
+    case GGML_OP_CONV_TRANSPOSE_1D:
+        {
+            elements = {uint32_t(src0->ne[1]), 1, 1}; // parallelize in {Cout, 1, 1}
+        } break;
     case GGML_OP_POOL_2D:
         {
             const uint32_t N = dst->ne[3];
@@ -7553,6 +7580,37 @@ static void ggml_vk_timestep_embedding(ggml_backend_vk_context * ctx, vk_context
     }, dryrun);
 }
 
+static void ggml_vk_conv_transpose_1d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+    // src0: (K, Cout, Cin, 1) -- kernel
+    // src1: (L, Cin, 1, 1) -- input
+    // dst: (*, Cout, 1, 1)
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    GGML_ASSERT(nb00 == sizeof(float));
+    GGML_ASSERT(nb10 == sizeof(float));
+
+    const int32_t s0 = dst->op_params[0];
+
+    vk_op_conv_transpose_1d_push_constants p{};
+    p.Cout = static_cast<uint32_t>(ne01);
+    p.Cin  = static_cast<uint32_t>(ne02);
+    p.K    = static_cast<uint32_t>(ne00);
+    p.L    = static_cast<uint32_t>(ne10);
+    p.KL   = static_cast<uint32_t>(ne0);
+    p.nb01 = static_cast<uint32_t>(nb01 / nb00);
+    p.nb02 = static_cast<uint32_t>(nb02 / nb00);
+    p.nb11 = static_cast<uint32_t>(nb11 / nb10);
+    p.nb1  = static_cast<uint32_t>(nb1 / nb0);
+    p.s0   = static_cast<uint32_t>(s0);
+
+    ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONV_TRANSPOSE_1D, std::move(p), dryrun);
+}
+
 static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
     uint32_t op = static_cast<uint32_t>(dst->op_params[0]);
     const int32_t k1 = dst->op_params[1];
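Note (illustration, not part of the diff): the push constants carry element strides rather than byte strides, hence the nb01/nb00 divisions above, and p.KL takes the output length dst->ne[0]. For a transposed 1D convolution that length works out to the familiar formula, assuming padding 0 and dilation 1, which is what the shader's layout implies:

    // sketch: expected dst length for conv_transpose_1d with p0 == 0, d0 == 1
    static int conv_transpose_1d_out_len(int L, int K, int s0) {
        return (L - 1)*s0 + K;
    }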
@@ -8624,6 +8682,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_COUNT_EQUAL:
     case GGML_OP_IM2COL:
     case GGML_OP_TIMESTEP_EMBEDDING:
+    case GGML_OP_CONV_TRANSPOSE_1D:
     case GGML_OP_POOL_2D:
     case GGML_OP_CONV_2D_DW:
     case GGML_OP_RWKV_WKV6:
@@ -8688,6 +8747,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_COUNT_EQUAL:
     case GGML_OP_IM2COL:
     case GGML_OP_TIMESTEP_EMBEDDING:
+    case GGML_OP_CONV_TRANSPOSE_1D:
     case GGML_OP_POOL_2D:
     case GGML_OP_CONV_2D_DW:
     case GGML_OP_LEAKY_RELU:
@@ -8859,6 +8919,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_TIMESTEP_EMBEDDING:
         ggml_vk_timestep_embedding(ctx, compute_ctx, src0, node, dryrun);
 
         break;
+    case GGML_OP_CONV_TRANSPOSE_1D:
+        ggml_vk_conv_transpose_1d(ctx, compute_ctx, src0, src1, node, dryrun);
+
+        break;
     case GGML_OP_POOL_2D:
         ggml_vk_pool_2d(ctx, compute_ctx, src0, node, dryrun);
@@ -8987,6 +9051,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
     case GGML_OP_COUNT_EQUAL:
     case GGML_OP_IM2COL:
     case GGML_OP_TIMESTEP_EMBEDDING:
+    case GGML_OP_CONV_TRANSPOSE_1D:
     case GGML_OP_POOL_2D:
     case GGML_OP_CONV_2D_DW:
     case GGML_OP_RWKV_WKV6:
@@ -9537,8 +9602,8 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
         if (ctx->device->query_pool) {
             ctx->device->device.destroyQueryPool(ctx->device->query_pool);
         }
-        VkQueryPoolCreateInfo query_create_info = { VK_STRUCTURE_TYPE_QUERY_POOL_CREATE_INFO };
-        query_create_info.queryType = VK_QUERY_TYPE_TIMESTAMP;
+        vk::QueryPoolCreateInfo query_create_info;
+        query_create_info.queryType = vk::QueryType::eTimestamp;
         query_create_info.queryCount = cgraph->n_nodes + 100;
         ctx->device->query_pool = ctx->device->device.createQueryPool(query_create_info);
         ctx->device->num_queries = query_create_info.queryCount;
@@ -9624,7 +9689,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
 
         // Get the results and pass them to the logger
         std::vector<uint64_t> timestamps(cgraph->n_nodes + 1);
-        ctx->device->device.getQueryPoolResults(ctx->device->query_pool, 0, cgraph->n_nodes + 1, (cgraph->n_nodes + 1)*sizeof(uint64_t), timestamps.data(), sizeof(uint64_t), vk::QueryResultFlagBits::e64 | vk::QueryResultFlagBits::eWait);
+        VK_CHECK(ctx->device->device.getQueryPoolResults(ctx->device->query_pool, 0, cgraph->n_nodes + 1, (cgraph->n_nodes + 1)*sizeof(uint64_t), timestamps.data(), sizeof(uint64_t), vk::QueryResultFlagBits::e64 | vk::QueryResultFlagBits::eWait), "get timestamp results");
         for (int i = 0; i < cgraph->n_nodes; i++) {
             if (!ggml_vk_is_empty(cgraph->nodes[i])) {
                 ctx->device->perf_logger->log_timing(cgraph->nodes[i], uint64_t((timestamps[i+1] - timestamps[i]) * ctx->device->properties.limits.timestampPeriod));
@@ -10048,6 +10113,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_LEAKY_RELU:
         case GGML_OP_OPT_STEP_ADAMW:
             return true;
+        case GGML_OP_CONV_TRANSPOSE_1D:
+            return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
         default:
             return false;
     }
@@ -10539,6 +10606,11 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
         const int32_t dim = tensor->op_params[0];
         const int32_t max_period = tensor->op_params[1];
         tensor_clone = ggml_timestep_embedding(ggml_ctx, src_clone[0], dim, max_period);
+    } else if (tensor->op == GGML_OP_CONV_TRANSPOSE_1D){
+        const int32_t s0 = tensor->op_params[0];
+        const int32_t p0 = tensor->op_params[1];
+        const int32_t d0 = tensor->op_params[2];
+        tensor_clone = ggml_conv_transpose_1d(ggml_ctx, src_clone[0], src_clone[1], s0, p0, d0);
     } else if (tensor->op == GGML_OP_POOL_2D) {
         enum ggml_op_pool op = static_cast<ggml_op_pool>(tensor->op_params[0]);
         const int32_t k0 = tensor->op_params[1];
ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp (new file, 98 lines)
@@ -0,0 +1,98 @@
#version 450

#include "types.comp"

layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; // src0 - kernel: [K, Cout, Cin]
layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; // src1 - input: [L, Cin]
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; // dst - result [KL, Cout]

layout(local_size_x = 128 , local_size_y = 1, local_size_z = 1) in;

layout (push_constant) uniform parameter {
    uint32_t Cout;
    uint32_t Cin;
    uint32_t K;
    uint32_t L;
    uint32_t KL;

    uint32_t nb01;
    uint32_t nb02;
    uint32_t nb11;
    uint32_t nb1;

    int32_t s0;
} p;


uint32_t Cout_idx = gl_WorkGroupID.x;
const uint32_t bs = gl_WorkGroupSize.x;
uint32_t tid = gl_LocalInvocationID.x;
// Code is more straightforward if we assume it is bs*s0+K instead of (bs-1)*s0+K.
uint32_t tmp_len = bs*p.s0+p.K;
shared D_TYPE tmp[4096];

uint splitWork(uint workSize){
    return (bs + workSize -1) / bs;
}

void main(){
    for(uint32_t i = 0; i < splitWork(tmp_len); i++){
        uint32_t idx = i*bs+tid;
        if(idx < tmp_len){
            tmp[idx] = 0.0;
        }
    }

    uint32_t L_blocks = splitWork(p.L);
    for(uint32_t L_block_id = 0; L_block_id < L_blocks; L_block_id++){
        if(L_block_id > 0){
            barrier();
            // Shift values in tmp to the current processing window
            for(int i = 0; i < splitWork(tmp_len); i++){
                uint32_t idx = i*bs+tid;
                if(idx >= bs*p.s0 && idx < tmp_len){
                    tmp[idx-bs*p.s0] = tmp[idx];
                    tmp[idx] = 0.0;
                }else if(idx >= p.K && idx < bs*p.s0){
                    tmp[idx] = 0.0;
                }
            }
        }
        barrier();

        // Save contributions of the block to tmp
        uint32_t L_idx = L_block_id*bs + tid;
        for(uint32_t K_idx = 0; K_idx < p.K; K_idx++){
            D_TYPE dp = 0.0;
            for(uint32_t Cin_idx = 0; Cin_idx < p.Cin; Cin_idx++){
                A_TYPE elemKrn = data_a[K_idx + Cout_idx * p.nb01 + Cin_idx * p.nb02];
                if(L_idx < p.L){
                    B_TYPE elemInp = data_b[L_idx + Cin_idx*p.nb11];
                    dp = fma(elemKrn, elemInp, dp);
                }
            }
            tmp[tid*p.s0 + K_idx] += dp;
            barrier();
        }

        // Save the computed values except the last block that can have different size
        uint32_t KLb_idx = L_block_id*bs*p.s0;
        if(L_block_id < L_blocks-1){
            for(uint32_t s0_idx = 0; s0_idx < p.s0; s0_idx++){
                uint32_t sh_idx = p.s0*tid+s0_idx;
                uint32_t KL_idx = KLb_idx+sh_idx;
                if(KL_idx < p.KL){
                    data_d[KL_idx + Cout_idx*p.nb1] = tmp[sh_idx];
                }
            }
        }
    }

    for(uint32_t i = 0; i < splitWork(tmp_len); i++){
        uint32_t idx = i*bs+tid;
        uint32_t KL_idx = (L_blocks-1)*bs*p.s0+idx;
        if(KL_idx < p.KL){
            data_d[KL_idx + Cout_idx*p.nb1] = tmp[idx];
        }
    }
}
@@ -636,6 +636,8 @@ void process_shaders() {
 
     string_to_spv("timestep_embedding_f32", "timestep_embedding.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
 
+    string_to_spv("conv_transpose_1d_f32", "conv_transpose_1d.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
+
     string_to_spv("pool2d_f32", "pool2d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
 
     string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
@@ -429,22 +429,54 @@ const llama_kv_cache * llama_context::get_kv_self() const {
     return kv_self;
 }
 
-bool llama_context::kv_self_update() {
+void llama_context::kv_self_defrag_sched() {
+    if (!memory) {
+        return;
+    }
+
+    memory_force_optimize = true;
+}
+
+bool llama_context::kv_self_update(bool optimize) {
     if (!memory) {
         return false;
     }
 
     llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
 
-    if (!kv_self->update(*this)) {
-        // no updates have been performed
-        return false;
+    {
+        // TODO: remove in the future
+        optimize |= memory_force_optimize;
+        memory_force_optimize = false;
+
+        const auto kv_state = kv_self->init_update(this, optimize);
+        switch (kv_state->get_status()) {
+            case LLAMA_MEMORY_STATUS_SUCCESS:
+                {
+                    // noop
+                } break;
+            case LLAMA_MEMORY_STATUS_NO_UPDATE:
+                {
+                    // no updates need to be performed
+                    return false;
+                }
+            case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
+            case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
+                {
+                    LLAMA_LOG_ERROR("%s: failed to prepare memory update\n", __func__);
+                    return false;
+                }
+        }
+
+        if (!kv_state->apply()) {
+            LLAMA_LOG_ERROR("%s: failed to apply memory update\n", __func__);
+        }
     }
 
     // if the KV cache did any computation, we have to reserve a new worst-case graph
     const auto kv_state = kv_self->init_full();
     if (!kv_state) {
-        throw std::runtime_error("failed to initialize KV cache");
+        throw std::runtime_error("failed to initialize memory state");
     }
 
     const uint32_t n_seqs = cparams.n_seq_max;
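Note (illustration, not part of the diff): the refactor replaces the boolean kv_self->update(*this) with a staged protocol: init_update() prepares the work and reports a status, apply() commits it, and the deprecated defrag entry point now only raises memory_force_optimize for the next update. A minimal self-contained sketch of the calling pattern, with simplified stand-in types (not the real API):

    enum class mem_status { success, no_update, failed_prepare };

    struct mem_update_state {
        mem_status status = mem_status::no_update;
        mem_status get_status() const { return status; }
        bool apply() { return status == mem_status::success; } // commit staged work
    };

    // caller side, mirroring llama_context::kv_self_update above
    static bool run_update(mem_update_state & st) {
        switch (st.get_status()) {
            case mem_status::success:        break;        // proceed to apply
            case mem_status::no_update:      return false; // nothing to do
            case mem_status::failed_prepare: return false; // error path
        }
        return st.apply();
    }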
@@ -452,7 +484,7 @@ bool llama_context::kv_self_update(bool optimize) {
 
     auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
     if (!gf) {
-        LLAMA_LOG_ERROR("%s: failed to reserve graph after the KV cache update\n", __func__);
+        LLAMA_LOG_ERROR("%s: failed to reserve graph after the memory update\n", __func__);
     }
 
     return true;
@@ -940,13 +972,13 @@ int llama_context::decode(llama_batch & inp_batch) {
         n_outputs_all = 1;
     }
 
+    bool did_optimize = false;
+
     // handle any pending defrags/shifts
-    kv_self_update();
+    kv_self_update(false);
 
     llama_memory_state_ptr kv_state;
 
-    bool did_defrag = false;
-
     while (true) {
         kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
         if (!kv_state) {
@@ -957,25 +989,32 @@ int llama_context::decode(llama_batch & inp_batch) {
             case LLAMA_MEMORY_STATUS_SUCCESS:
                 {
                 } break;
+            case LLAMA_MEMORY_STATUS_NO_UPDATE:
+                {
+                    LLAMA_LOG_ERROR("%s: unexpected memory state status: %d\n", __func__, kv_state->get_status());
+
+                    return -2;
+                }
             case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
                 {
-                    if (!did_defrag) {
-                        did_defrag = true;
+                    if (!did_optimize) {
+                        did_optimize = true;
 
-                        kv_self->defrag_sched(-1.0f);
-                        if (kv_self_update()) {
-                            LLAMA_LOG_DEBUG("%s: failed to init batch of size %d, retrying after defrag\n", __func__, batch.n_tokens);
+                        if (kv_self_update(true)) {
+                            LLAMA_LOG_DEBUG("%s: retrying batch size %d after cache optimization\n", __func__, batch.n_tokens);
 
                             continue;
                         }
                     }
 
-                    LLAMA_LOG_WARN("%s: failed to find KV cache slot for batch of size %d\n", __func__, batch.n_tokens);
+                    LLAMA_LOG_WARN("%s: failed to find a memory slot for batch of size %d\n", __func__, batch.n_tokens);
 
                     return 1;
                 }
             case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
                 {
                     LLAMA_LOG_ERROR("%s: compute failed while preparing batch of size %d\n", __func__, batch.n_tokens);
 
                     return -2;
                 }
         }
@@ -1189,11 +1228,6 @@ int llama_context::decode(llama_batch & inp_batch) {
     // wait for the computation to finish (automatically done when obtaining the model output)
     //synchronize();
 
-    // decide if we need to defrag the kv cache
-    if (cparams.defrag_thold > 0.0f) {
-        kv_self->defrag_sched(cparams.defrag_thold);
-    }
-
     // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
     // overlap with device computation.
     ggml_backend_sched_reset(sched.get());
@@ -2283,7 +2317,7 @@ llama_kv_cache * llama_get_kv_self(llama_context * ctx) {
 
 // deprecated
 void llama_kv_self_update(llama_context * ctx) {
-    ctx->kv_self_update();
+    ctx->kv_self_update(false);
 }
 
 enum llama_pooling_type llama_pooling_type(const llama_context * ctx) {
@@ -2538,13 +2572,8 @@ llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
 
 // deprecated
 void llama_kv_self_defrag(llama_context * ctx) {
-    auto * kv = ctx->get_kv_self();
-    if (!kv) {
-        return;
-    }
-
-    // force defrag
-    kv->defrag_sched(-1.0f);
+    ctx->kv_self_defrag_sched();
 }
 
 bool llama_kv_self_can_shift(const llama_context * ctx) {
@@ -52,7 +52,8 @@ struct llama_context {
 
     // return true of the KV cache was updated
     // TODO: remove
-    bool kv_self_update();
+    bool kv_self_update(bool optimize);
+    void kv_self_defrag_sched();
 
     enum llama_pooling_type pooling_type() const;
 
@@ -231,6 +232,9 @@ private:
 
     std::unique_ptr<llama_memory_i> memory;
 
+    // TODO: temporary, until the llama_kv_self_defrag() API is removed
+    bool memory_force_optimize = false;
+
     // decode output (2-dimensional array: [n_outputs][n_vocab])
     size_t  logits_size = 0; // capacity (of floats) for logits
     float * logits      = nullptr;
@@ -769,9 +769,8 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
         cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);
 
         if (weight_before_ffn) {
-            // TODO: this is a workaround as we don't yet have a repeat op that takes custom dim (ggml_repeat_4d)
-            ggml_tensor * repeated = ggml_new_tensor_3d(ctx0, cur->type, n_embd, n_expert_used, n_tokens);
-            repeated = ggml_repeat(ctx0, cur, repeated); // [n_embd, n_expert_used, n_tokens]
+            // repeat cur to [n_embd, n_expert_used, n_tokens]
+            ggml_tensor * repeated = ggml_repeat_4d(ctx0, cur, n_embd, n_expert_used, n_tokens, 1);
             cur = ggml_mul(ctx0, repeated, weights);
             cb(cur, "ffn_moe_weighted", il);
         }
@@ -1,6 +1,7 @@
 #include "llama-kv-cache-recurrent.h"
 
+#include "llama-impl.h"
 #include "llama-io.h"
 #include "llama-batch.h"
 #include "llama-model.h"
 
@@ -386,6 +387,13 @@ llama_memory_state_ptr llama_kv_cache_recurrent::init_full() {
     return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_SUCCESS, this);
 }
 
+llama_memory_state_ptr llama_kv_cache_recurrent::init_update(llama_context * lctx, bool optimize) {
+    GGML_UNUSED(lctx);
+    GGML_UNUSED(optimize);
+
+    return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_NO_UPDATE);
+}
+
 bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatches) {
     // simply remember the full state because it is very small for this type of cache
     // TODO: optimize
@@ -419,17 +427,6 @@ bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatche
     return success;
 }
 
-bool llama_kv_cache_recurrent::update(llama_context & lctx) {
-    GGML_UNUSED(lctx);
-    // noop
-    return false;
-}
-
-void llama_kv_cache_recurrent::defrag_sched(float thold) {
-    GGML_UNUSED(thold);
-    // noop
-}
-
 bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
     const uint32_t n_tokens = ubatch.n_tokens;
     const uint32_t n_seqs   = ubatch.n_seqs;
@@ -52,9 +52,7 @@ public:
 
     llama_memory_state_ptr init_full() override;
 
-    bool update(llama_context & lctx) override;
-
-    void defrag_sched(float thold) override;
+    llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
 
     bool prepare(const std::vector<llama_ubatch> & ubatches);
 
@@ -123,26 +123,16 @@ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch

    assert(heads_base.size() == heads_swa.size());

    return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_SUCCESS,
    return std::make_unique<llama_kv_cache_unified_iswa_state>(
        this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
}

llama_memory_state_ptr llama_kv_cache_unified_iswa::init_full() {
    return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_SUCCESS, this);
    return std::make_unique<llama_kv_cache_unified_iswa_state>(this);
}

bool llama_kv_cache_unified_iswa::update(llama_context & lctx) {
    bool res = false;

    res = res | kv_base->update(lctx);
    res = res | kv_swa ->update(lctx);

    return res;
}

void llama_kv_cache_unified_iswa::defrag_sched(float thold) {
    kv_base->defrag_sched(thold);
    kv_swa ->defrag_sched(thold);
llama_memory_state_ptr llama_kv_cache_unified_iswa::init_update(llama_context * lctx, bool optimize) {
    return std::make_unique<llama_kv_cache_unified_iswa_state>(this, lctx, optimize);
}

bool llama_kv_cache_unified_iswa::get_can_shift() const {
@@ -174,26 +164,38 @@ llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const {
llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(llama_memory_status status) : status(status) {}

llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
    llama_memory_status status,
    llama_kv_cache_unified_iswa * kv) : status(status) {
    state_base.reset(new llama_kv_cache_unified_state(status, kv->get_base()));
    state_swa .reset(new llama_kv_cache_unified_state(status, kv->get_swa ()));
    llama_kv_cache_unified_iswa * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS) {
    state_base = kv->get_base()->init_full();
    state_swa  = kv->get_swa ()->init_full();

    status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
}

llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
    llama_kv_cache_unified_iswa * kv,
    llama_context * lctx,
    bool optimize) : status(LLAMA_MEMORY_STATUS_SUCCESS) {
    state_base = kv->get_base()->init_update(lctx, optimize);
    state_swa  = kv->get_swa ()->init_update(lctx, optimize);

    status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
}

llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
    llama_memory_status status,
    llama_kv_cache_unified_iswa * kv,
    llama_sbatch sbatch,
    std::vector<uint32_t> heads_base,
    std::vector<uint32_t> heads_swa,
    std::vector<llama_ubatch> ubatches)
    : status(status),
    sbatch(std::move(sbatch)),
    ubatches(std::move(ubatches)) {
    // note: here we copy the ubatches. not sure if this is ideal
    state_base.reset(new llama_kv_cache_unified_state(status, kv->get_base(), {}, std::move(heads_base), this->ubatches));
    state_swa .reset(new llama_kv_cache_unified_state(status, kv->get_swa (), {}, std::move(heads_swa), this->ubatches));
}
    : status(LLAMA_MEMORY_STATUS_SUCCESS),
    sbatch(std::move(sbatch)),
    ubatches(std::move(ubatches)) {
    // note: here we copy the ubatches. not sure if this is ideal
    state_base.reset(new llama_kv_cache_unified_state(kv->get_base(), {}, std::move(heads_base), this->ubatches));
    state_swa .reset(new llama_kv_cache_unified_state(kv->get_swa (), {}, std::move(heads_swa), this->ubatches));

    status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
}

llama_kv_cache_unified_iswa_state::~llama_kv_cache_unified_iswa_state() = default;

@@ -233,17 +235,18 @@ llama_memory_status llama_kv_cache_unified_iswa_state::get_status() const {

const llama_ubatch & llama_kv_cache_unified_iswa_state::get_ubatch() const {
    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

    return ubatches[i_next];
}

const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_base() const {
    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

    return state_base.get();
    return static_cast<const llama_kv_cache_unified_state *>(state_base.get());
}

const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_swa() const {
    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

    return state_swa.get();
    return static_cast<const llama_kv_cache_unified_state *>(state_swa.get());
}

@@ -54,9 +54,7 @@ public:

    llama_memory_state_ptr init_full() override;

    bool update(llama_context & lctx) override;

    void defrag_sched(float thold) override;
    llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;

    bool get_can_shift() const override;

@@ -86,12 +84,16 @@ public:

    // used to create a full-cache state
    llama_kv_cache_unified_iswa_state(
        llama_memory_status status,
        llama_kv_cache_unified_iswa * kv);

    // used to create an update state
    llama_kv_cache_unified_iswa_state(
        llama_kv_cache_unified_iswa * kv,
        llama_context * lctx,
        bool optimize);

    // used to create a state from a batch
    llama_kv_cache_unified_iswa_state(
        llama_memory_status status,
        llama_kv_cache_unified_iswa * kv,
        llama_sbatch sbatch,
        std::vector<uint32_t> heads_base,
@@ -120,7 +122,7 @@ public:
    const llama_kv_cache_unified_state * get_swa() const;

private:
    const llama_memory_status status;
    llama_memory_status status;

    //llama_kv_cache_unified_iswa * kv;

@@ -131,6 +133,6 @@ private:

    std::vector<llama_ubatch> ubatches;

    std::unique_ptr<llama_kv_cache_unified_state> state_base;
    std::unique_ptr<llama_kv_cache_unified_state> state_swa;
    llama_memory_state_ptr state_base;
    llama_memory_state_ptr state_swa;
};

@@ -1,6 +1,7 @@
#include "llama-kv-cache-unified.h"

#include "llama-impl.h"
#include "llama-io.h"
#include "llama-model.h"
#include "llama-context.h"

@@ -149,12 +150,27 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
        p1 = std::numeric_limits<llama_pos>::max();
    }

    for (uint32_t i = 0; i < cells.size(); ++i) {
        if (!cells.pos_in(i, p0, p1)) {
            continue;
        }
    if (seq_id >= 0) {
        for (uint32_t i = 0; i < cells.size(); ++i) {
            if (!cells.pos_in(i, p0, p1)) {
                continue;
            }

            if (cells.seq_has(i, seq_id) && cells.seq_rm(i, seq_id)) {
                if (new_head == cells.size()) {
                    new_head = i;
                }
            }
        }
    } else {
        // match any sequence
        for (uint32_t i = 0; i < cells.size(); ++i) {
            if (!cells.pos_in(i, p0, p1)) {
                continue;
            }

            cells.rm(i);

        if (cells.seq_has(i, seq_id) && cells.seq_rm(i, seq_id)) {
            if (new_head == cells.size()) {
                new_head = i;
            }

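Note: the restructured loop above gives seq_id < 0 an explicit "match any sequence" meaning. A hedged usage sketch via the public wrapper that appears later in this diff (the position values are illustrative):

// remove positions [0, 32) from sequence 0 only:
llama_kv_self_seq_rm(ctx, /*seq_id =*/ 0, /*p0 =*/ 0, /*p1 =*/ 32);

// seq_id == -1: drop matching cells no matter which sequences reference them
llama_kv_self_seq_rm(ctx, /*seq_id =*/ -1, /*p0 =*/ 0, /*p1 =*/ -1);
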
@@ -305,16 +321,49 @@ llama_memory_state_ptr llama_kv_cache_unified::init_batch(
        return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
    }

    return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_SUCCESS,
    return std::make_unique<llama_kv_cache_unified_state>(
        this, std::move(sbatch), std::move(heads), std::move(ubatches));
}

llama_memory_state_ptr llama_kv_cache_unified::init_full() {
    return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_SUCCESS, this);
    return std::make_unique<llama_kv_cache_unified_state>(this);
}

std::vector<uint32_t> llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
    std::vector<uint32_t> res;
llama_memory_state_ptr llama_kv_cache_unified::init_update(llama_context * lctx, bool optimize) {
    bool do_shift = get_has_shift();

    defrag_info dinfo;

    // see if we need to defrag
    {
        bool do_defrag = optimize;

        const auto thold = lctx->get_cparams().defrag_thold;

        if (!do_defrag && thold > 0.0f) {
            const auto n_kv = cells.used_max_p1();

            // - do not defrag small contexts (i.e. < 2048 tokens)
            // - count the padding towards the number of used tokens
            const float fragmentation = n_kv >= 2048 ? std::max(0.0f, 1.0f - (float(cells.get_used() + n_pad)/n_kv)) : 0.0f;

            if (fragmentation > thold) {
                LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);

                do_defrag = true;
            }
        }

        if (do_defrag) {
            dinfo = defrag_prepare(lctx->graph_max_nodes());
        }
    }

    return std::make_unique<llama_kv_cache_unified_state>(this, lctx, do_shift, std::move(dinfo));
}

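Note: the defrag trigger above is a plain occupancy ratio. A self-contained recomputation with illustrative numbers (none of them taken from this commit):

#include <algorithm>
#include <cstdio>

int main() {
    const unsigned n_kv  = 4096;  // cells.used_max_p1(): highest used cell + 1
    const unsigned used  = 2048;  // cells.get_used()
    const unsigned n_pad = 32;    // padding counted towards the used tokens
    const float    thold = 0.10f; // defrag_thold from the context params

    // contexts smaller than 2048 cells are never defragged
    const float fragmentation = n_kv >= 2048
        ? std::max(0.0f, 1.0f - float(used + n_pad)/n_kv)
        : 0.0f;

    // 1 - 2080/4096 ~= 0.49 > 0.10, so a defrag would be requested here
    std::printf("fragmentation = %.3f -> defrag = %s\n", fragmentation, fragmentation > thold ? "yes" : "no");
}
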
llama_kv_cache_unified::ubatch_heads llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
    llama_kv_cache_unified::ubatch_heads res;

    struct state {
        uint32_t head_old; // old position of the head, before placing the ubatch
@@ -359,12 +408,12 @@ std::vector<uint32_t> llama_kv_cache_unified::prepare(const std::vector<llama_ub
    return res;
}

bool llama_kv_cache_unified::update(llama_context & lctx) {
bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const defrag_info & dinfo) {
    bool updated = false;

    auto * sched = lctx.get_sched();
    auto * sched = lctx->get_sched();

    if (cells.get_has_shift()) {
    if (do_shift) {
        if (!get_can_shift()) {
            printf("\nWARNING: The current KV cache / model configuration does not support K-shift");
        } else {
@@ -375,9 +424,9 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
            if (hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
                ggml_backend_sched_reset(sched);

                auto * gf = lctx.graph_init();
                auto * gf = lctx->graph_init();

                auto res = build_graph_shift(lctx.get_cparams(), lctx.get_ctx_compute(), gf);
                auto res = build_graph_shift(lctx->get_cparams(), lctx->get_ctx_compute(), gf);
                if (!res) {
                    LLAMA_LOG_ERROR("%s: failed to build graph for K-shift\n", __func__);
                    return updated;
@@ -390,7 +439,7 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {

                res->set_inputs(nullptr);

                if (lctx.graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
                if (lctx->graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
                    LLAMA_LOG_ERROR("%s: failed to compute K-shift\n", __func__);
                    return updated;
                }
@@ -401,56 +450,55 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
            cells.reset_shift();
        }
    }

    if (do_defrag) {
    if (!dinfo.empty()) {
        LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);

        if (defrag_prepare(lctx.graph_max_nodes())) {
            ggml_backend_sched_reset(sched);
        // apply moves:
        {
            const auto n_kv = dinfo.ids.size();

            auto * gf = lctx.graph_init();
            for (uint32_t i = 0; i < n_kv; ++i) {
                assert(dinfo.ids[i] <= n_kv);

            auto res = build_graph_defrag(lctx.get_cparams(), lctx.get_ctx_compute(), gf);
            if (!res) {
                LLAMA_LOG_ERROR("%s: failed to build graph for defrag\n", __func__);
                return updated;
                if (dinfo.ids[i] == n_kv) {
                    continue;
                }

                cells.mv(i, dinfo.ids[i]);
            }

            if (!ggml_backend_sched_alloc_graph(sched, gf)) {
                LLAMA_LOG_ERROR("%s: failed to allocate compute graph for defrag\n", __func__);
                return updated;
            }

            res->set_inputs(nullptr);

            if (lctx.graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
                LLAMA_LOG_ERROR("%s: failed to compute defrag\n", __func__);
                return updated;
            }

            updated = true;
            // reset the head so we can find the first free slot during the next ubatch
            head = 0;
        }

        do_defrag = false;
        ggml_backend_sched_reset(sched);

        auto * gf = lctx->graph_init();

        auto res = build_graph_defrag(lctx->get_cparams(), lctx->get_ctx_compute(), gf, dinfo);
        if (!res) {
            LLAMA_LOG_ERROR("%s: failed to build graph for defrag\n", __func__);
            return updated;
        }

        if (!ggml_backend_sched_alloc_graph(sched, gf)) {
            LLAMA_LOG_ERROR("%s: failed to allocate compute graph for defrag\n", __func__);
            return updated;
        }

        res->set_inputs(nullptr);

        if (lctx->graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
            LLAMA_LOG_ERROR("%s: failed to compute defrag\n", __func__);
            return updated;
        }

        updated = true;
    }

    return updated;
}

void llama_kv_cache_unified::defrag_sched(float thold) {
    const auto n_kv = cells.used_max_p1();

    // - do not defrag small contexts (i.e. < 2048 tokens)
    // - count the padding towards the number of used tokens
    const float fragmentation = n_kv >= 2048 ? std::max(0.0f, 1.0f - (float(cells.get_used() + n_pad)/n_kv)) : 0.0f;

    // queue defragmentation for next llama_kv_cache_update
    if (fragmentation > thold) {
        LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);

        do_defrag = true;
    }
}

int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
    const uint32_t n_tokens = ubatch.n_tokens;

@@ -597,6 +645,10 @@ uint32_t llama_kv_cache_unified::get_size() const {
    return cells.size();
}

bool llama_kv_cache_unified::get_has_shift() const {
    return cells.get_has_shift();
}

uint32_t llama_kv_cache_unified::get_n_kv() const {
    return std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad)));
}
@@ -926,12 +978,13 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
}

llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
        const llama_cparams & cparams,
        ggml_context * ctx,
        ggml_cgraph * gf) const {
        const llama_cparams & cparams,
        ggml_context * ctx,
        ggml_cgraph * gf,
        const defrag_info & dinfo) const {
    auto res = std::make_unique<llm_graph_result>();

    const auto & ids = defrag_info.ids;
    const auto & ids = dinfo.ids;

#if 0
    // CPU defrag
@@ -1072,7 +1125,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
    return res;
}

bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
llama_kv_cache_unified::defrag_info llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) const {
    const uint32_t n_layer = layers.size();

    const uint32_t n_kv = cells.used_max_p1();
@@ -1093,14 +1146,9 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
    const uint32_t max_moves = (n_max_nodes - 2*n_layer)/(6*n_layer);

    // determine which KV cells to move where
    //
    //   cell i moves to ids[i]
    //
    //   if ids[i] == i || ids[i] == n_kv, then cell i is not moved
    //
    auto & ids = defrag_info.ids;
    defrag_info res;
    auto & ids = res.ids;

    ids.clear();
    ids.resize(n_kv, n_kv);

    for (uint32_t i0 = 0; i0 < n_used; ++i0) {
@@ -1164,11 +1212,6 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
                // this cell goes to (i0 + nf)
                ids[i1] = i0 + nf;

                // move the cell meta data
                cells.mv(i1, i0 + nf);

                head = n_used;

                if (!cont) {
                    n_moves++;
                    cont = true;
@@ -1191,14 +1234,14 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
    }

    if (n_moves == 0) {
        return false;
        return {};
    }

    LLAMA_LOG_DEBUG("%s: (tmp log) KV defrag cell moves: %u\n", __func__, n_moves);

    LLAMA_LOG_DEBUG("%s: expected gf nodes: %u\n", __func__, 6*n_moves*n_layer);

    return true;
    return res;
}

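Note: to make the ids convention in defrag_prepare() concrete, a toy sketch that applies a move table to a flat cell array (in the real code the moves are executed by cells.mv() and the defrag graph; all values below are illustrative):

#include <cstdio>
#include <vector>

int main() {
    const unsigned n_kv = 5;

    // cells 0, 2, 4 are used; defrag compacts them to the front:
    // cell i moves to ids[i]; ids[i] == i or ids[i] == n_kv means "not moved"
    const std::vector<unsigned> ids = { 0, n_kv, 1, n_kv, 2 };

    std::vector<char> cells = { 'A', '.', 'C', '.', 'E' };
    std::vector<char> moved = cells;

    for (unsigned i = 0; i < n_kv; ++i) {
        if (ids[i] == i || ids[i] == n_kv) {
            continue;
        }
        moved[ids[i]] = cells[i];
    }

    // prints "A C E . E" - the used cells are now contiguous; the stale tail
    // gets reclaimed once the head is reset to the first free slot
    for (char c : moved) {
        std::printf("%c ", c);
    }
    std::printf("\n");
}
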
bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
@@ -1621,24 +1664,27 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
llama_kv_cache_unified_state::llama_kv_cache_unified_state(llama_memory_status status) : status(status) {}

llama_kv_cache_unified_state::llama_kv_cache_unified_state(
    llama_memory_status status,
    llama_kv_cache_unified * kv) : status(status), kv(kv) {
    n_kv = kv->get_size();
    head = 0;
}
    llama_kv_cache_unified * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) {
    n_kv = kv->get_size();
    head = 0;
}

llama_kv_cache_unified_state::llama_kv_cache_unified_state(
    llama_memory_status status,
    llama_kv_cache_unified * kv,
    llama_sbatch sbatch,
    std::vector<uint32_t> heads,
    std::vector<llama_ubatch> ubatches)
    : status(status),
    kv(kv),
    sbatch(std::move(sbatch)),
    heads(std::move(heads)),
    ubatches(std::move(ubatches)) {
    llama_kv_cache_unified * kv,
    llama_context * lctx,
    bool do_shift,
    defrag_info dinfo) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), lctx(lctx), do_shift(do_shift), dinfo(std::move(dinfo)) {
    if (!do_shift && dinfo.empty()) {
        status = LLAMA_MEMORY_STATUS_NO_UPDATE;
    }
}

llama_kv_cache_unified_state::llama_kv_cache_unified_state(
    llama_kv_cache_unified * kv,
    llama_sbatch sbatch,
    llama_kv_cache_unified::ubatch_heads heads,
    std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sbatch(std::move(sbatch)), heads(std::move(heads)), ubatches(std::move(ubatches)) {
}

llama_kv_cache_unified_state::~llama_kv_cache_unified_state() = default;

@@ -1655,6 +1701,13 @@ bool llama_kv_cache_unified_state::next() {
bool llama_kv_cache_unified_state::apply() {
    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

    // no ubatches -> this is a KV cache update
    if (ubatches.empty()) {
        kv->update(lctx, do_shift, dinfo);

        return true;
    }

    kv->apply_ubatch(heads[i_next], ubatches[i_next]);

    n_kv = kv->get_n_kv();

@@ -24,6 +24,19 @@ public:
    // this callback is used to filter out layers that should not be included in the cache
    using layer_filter_cb = std::function<bool(int32_t il)>;

    using ubatch_heads = std::vector<uint32_t>;

    struct defrag_info {
        bool empty() const {
            return ids.empty();
        }

        // contains information about which cell moves where:
        //  - cell i moves to ids[i]
        //  - if ids[i] == i || ids[i] == ids.size(), then cell i is not moved
        std::vector<uint32_t> ids;
    };

    llama_kv_cache_unified(
            const llama_model & model,
            layer_filter_cb && filter,
@@ -66,9 +79,7 @@ public:

    llama_memory_state_ptr init_full() override;

    bool update(llama_context & lctx) override;

    void defrag_sched(float thold) override;
    llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;

    bool get_can_shift() const override;

@@ -83,6 +94,8 @@ public:

    uint32_t get_size() const;

    bool get_has_shift() const;

    //
    // graph_build API
    //
@@ -103,7 +116,9 @@ public:

    // find places for the provided ubatches in the cache, returns the head locations
    // return empty vector on failure
    std::vector<uint32_t> prepare(const std::vector<llama_ubatch> & ubatches);
    ubatch_heads prepare(const std::vector<llama_ubatch> & ubatches);

    bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo);

    // return the cell position where we can insert the ubatch
    // return -1 on failure to find a contiguous slot of kv cells
@@ -133,8 +148,7 @@ private:
        ggml_tensor * v;
    };

    bool do_defrag = false;
    bool v_trans = true; // the value tensor is transposed
    bool v_trans = true; // the value tensor is transposed

    // the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
    // note: this is not part of the KV state and it's only used to speed-up the find_slot() method
@@ -160,13 +174,8 @@ private:
    // model layer id -> KV cache layer id
    std::unordered_map<int32_t, int32_t> map_layer_ids;

    // defrag
    struct {
        std::vector<uint32_t> ids;
    } defrag_info;

    // return true if cells have been moved
    bool defrag_prepare(int32_t n_max_nodes);
    // return non-empty vector if cells have been moved
    defrag_info defrag_prepare(int32_t n_max_nodes) const;

    size_t total_size() const;

@@ -192,7 +201,8 @@ private:
    llm_graph_result_ptr build_graph_defrag(
            const llama_cparams & cparams,
            ggml_context * ctx,
            ggml_cgraph * gf) const;
            ggml_cgraph * gf,
            const defrag_info & dinfo) const;

    void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
    void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
@@ -203,20 +213,29 @@ private:

class llama_kv_cache_unified_state : public llama_memory_state_i {
public:
    // some shorthands
    using ubatch_heads = llama_kv_cache_unified::ubatch_heads;
    using defrag_info  = llama_kv_cache_unified::defrag_info;

    // used for errors
    llama_kv_cache_unified_state(llama_memory_status status);

    // used to create a full-cache state
    llama_kv_cache_unified_state(
        llama_memory_status status,
        llama_kv_cache_unified * kv);

    // used to create a state from a batch
    // used to create an update state
    llama_kv_cache_unified_state(
        llama_kv_cache_unified * kv,
        llama_context * lctx,
        bool do_shift,
        defrag_info dinfo);

    // used to create a decode state from a batch
    llama_kv_cache_unified_state(
        llama_memory_status status,
        llama_kv_cache_unified * kv,
        llama_sbatch sbatch,
        std::vector<uint32_t> heads,
        ubatch_heads heads,
        std::vector<llama_ubatch> ubatches);

    virtual ~llama_kv_cache_unified_state();

@@ -253,16 +272,30 @@ public:
    void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;

private:
    const llama_memory_status status;
    llama_memory_status status;

    llama_kv_cache_unified * kv;
    llama_context * lctx;

    //
    // update state
    //

    bool do_shift = false;

    defrag_info dinfo;

    //
    // batch processing state
    //

    llama_sbatch sbatch;

    // the index of the next ubatch to process
    size_t i_next = 0;

    std::vector<uint32_t> heads;
    ubatch_heads heads;

    std::vector<llama_ubatch> ubatches;

    //

@@ -1,12 +1,16 @@
#pragma once

#include "llama.h"
#include "llama-io.h"
#include "llama-memory.h"

class llama_io_write_i;
class llama_io_read_i;

struct llama_kv_cache : public llama_memory_i {
    virtual ~llama_kv_cache() = default;

    // TODO: move the init_ interfaces to llama_memory_i

    // split the input batch into a set of ubatches and verify that they can fit into the cache
    // return a state object containing the ubatches and KV cache state required to process them
    // check the llama_memory_state_i::get_status() for the result
@@ -19,16 +23,9 @@ struct llama_kv_cache : public llama_memory_i {
    // simulate full cache, used for allocating worst-case compute buffers
    virtual llama_memory_state_ptr init_full() = 0;

    // process any pending defrag/shift/etc. operations
    // optionally call once before processing a new batch
    // return true if any operations were performed
    virtual bool update(llama_context & lctx) = 0;

    // schedule a defrag if the fragmentation threshold is exceeded. otherwise, do nothing
    // TODO: change to
    //   llama_memory_state_ptr init_defrag(float thold) = 0;
    //
    virtual void defrag_sched(float thold) = 0;
    // prepare for any pending memory updates, such as shifts, defrags, etc.
    // status == LLAMA_MEMORY_STATUS_NO_UPDATE if there is nothing to update
    virtual llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) = 0;

    // getters
    virtual bool get_can_shift() const = 0;

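Note: a hedged sketch of how a caller might drive the new interface, assuming the returned state object exposes get_status() and apply() as the unified-cache state earlier in this diff does (error handling elided):

static void maybe_update_memory(llama_kv_cache * kv, llama_context * lctx, bool optimize) {
    llama_memory_state_ptr st = kv->init_update(lctx, optimize);

    switch (st->get_status()) {
        case LLAMA_MEMORY_STATUS_NO_UPDATE:
            return; // nothing pending, no graph work was prepared
        case LLAMA_MEMORY_STATUS_SUCCESS:
            st->apply(); // performs the prepared shift and/or defrag
            return;
        default:
            return; // LLAMA_MEMORY_STATUS_FAILED_* - surface to the caller
    }
}
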
@@ -1 +1,42 @@
#include "llama-memory.h"

llama_memory_status llama_memory_status_combine(llama_memory_status s0, llama_memory_status s1) {
    bool has_update = false;

    switch (s0) {
        case LLAMA_MEMORY_STATUS_SUCCESS:
            {
                has_update = true;
                break;
            }
        case LLAMA_MEMORY_STATUS_NO_UPDATE:
            {
                break;
            }
        case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
        case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
            {
                return s0;
            }
    }

    switch (s1) {
        case LLAMA_MEMORY_STATUS_SUCCESS:
            {
                has_update = true;
                break;
            }
        case LLAMA_MEMORY_STATUS_NO_UPDATE:
            {
                break;
            }
        case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
        case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
            {
                return s1;
            }
    }

    // if either status has an update, then the combined status has an update
    return has_update ? LLAMA_MEMORY_STATUS_SUCCESS : LLAMA_MEMORY_STATUS_NO_UPDATE;
}

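Note: a minimal check of the combination rule implemented above - failures propagate (s0 first), otherwise SUCCESS wins over NO_UPDATE:

#include <cassert>

static void combine_checks() {
    assert(llama_memory_status_combine(LLAMA_MEMORY_STATUS_SUCCESS,   LLAMA_MEMORY_STATUS_NO_UPDATE)      == LLAMA_MEMORY_STATUS_SUCCESS);
    assert(llama_memory_status_combine(LLAMA_MEMORY_STATUS_NO_UPDATE, LLAMA_MEMORY_STATUS_NO_UPDATE)      == LLAMA_MEMORY_STATUS_NO_UPDATE);
    assert(llama_memory_status_combine(LLAMA_MEMORY_STATUS_SUCCESS,   LLAMA_MEMORY_STATUS_FAILED_PREPARE) == LLAMA_MEMORY_STATUS_FAILED_PREPARE);
    assert(llama_memory_status_combine(LLAMA_MEMORY_STATUS_FAILED_COMPUTE, LLAMA_MEMORY_STATUS_SUCCESS)   == LLAMA_MEMORY_STATUS_FAILED_COMPUTE);
}
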
@@ -36,12 +36,19 @@ public:
    virtual bool get_can_edit() const = 0;
};

using llama_memory_ptr = std::unique_ptr<llama_memory_i>;

enum llama_memory_status {
    LLAMA_MEMORY_STATUS_SUCCESS = 0,
    LLAMA_MEMORY_STATUS_NO_UPDATE,
    LLAMA_MEMORY_STATUS_FAILED_PREPARE,
    LLAMA_MEMORY_STATUS_FAILED_COMPUTE,
};

// helper function for combining the status of two memory states
// useful for implementing hybrid memory types (e.g. iSWA)
llama_memory_status llama_memory_status_combine(llama_memory_status s0, llama_memory_status s1);

// the interface for managing the memory state during batch processing
// this interface is implemented per memory type. see:
//  - llama_kv_cache_unified_state
@@ -69,7 +76,7 @@ public:
    // get the current ubatch
    virtual const llama_ubatch & get_ubatch() const = 0;

    // get the status of the memory state
    // get the status of the memory state - used for error handling and checking if any updates would be applied
    virtual llama_memory_status get_status() const = 0;
};

@@ -961,6 +961,11 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                    case 46: type = LLM_TYPE_27B; break;
                    default: type = LLM_TYPE_UNKNOWN;
                }

                // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/config.py#L173
                hparams.f_attention_scale = type == LLM_TYPE_27B
                    ? 1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0)))
                    : 1.0f / std::sqrt(float(hparams.n_embd_head_k));
            } break;
        case LLM_ARCH_GEMMA3:
            {
@@ -981,6 +986,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                    default: type = LLM_TYPE_UNKNOWN;
                }

                // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/config.py#L289
                hparams.f_attention_scale = type == LLM_TYPE_27B
                    ? 1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0)))
                    : 1.0f / std::sqrt(float(hparams.n_embd_head_k));

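Note: as a worked example of the 27B branch, with the published Gemma-2 27B shape (n_embd = 4608, n_head = 32, taken from the referenced gemma_pytorch config rather than from this diff) the query scale is 1/sqrt(4608/32) = 1/sqrt(144) = 1/12, matching the query_pre_attn_scalar = 144 behavior:

#include <cmath>
#include <cstdio>

int main() {
    const float n_embd = 4608.0f; // Gemma-2 27B hidden size (reference config)
    const float n_head = 32.0f;   // Gemma-2 27B attention heads (reference config)

    const float scale = 1.0f / std::sqrt(n_embd / n_head);
    std::printf("f_attention_scale = %f\n", scale); // ~0.083333 = 1/12
}
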
@@ -8584,14 +8590,7 @@ struct llm_build_gemma2_iswa : public llm_graph_context {
                cb(Kcur, "Kcur", il);
                cb(Vcur, "Vcur", il);

                // ref: https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e
                switch (model.type) {
                    case LLM_TYPE_2B:
                    case LLM_TYPE_9B:
                    case LLM_TYPE_27B: Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head))); break;
                    default: GGML_ABORT("fatal error");
                };
                cb(Qcur, "Qcur_scaled", il);
                Qcur = ggml_scale(ctx0, Qcur, hparams.f_attention_scale);

                cur = build_attn(inp_attn, gf,
                        model.layers[il].wo, NULL,
@@ -8732,9 +8731,12 @@ struct llm_build_gemma3_iswa : public llm_graph_context {
                cb(Kcur, "Kcur", il);
                cb(Vcur, "Vcur", il);

                // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/model.py#L315
                Qcur = ggml_scale(ctx0, Qcur, hparams.f_attention_scale);

                cur = build_attn(inp_attn, gf,
                        model.layers[il].wo, NULL,
                        Qcur, Kcur, Vcur, nullptr, nullptr, hparams.f_attention_scale, il);
                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
            }

            cur = build_norm(cur,

@@ -70,6 +70,7 @@ struct mtmd_cli_context {
    llama_model * model;
    llama_context * lctx;
    const llama_vocab * vocab;
    common_sampler * smpl;
    llama_batch batch;
    int n_batch;

@@ -89,8 +90,9 @@ struct mtmd_cli_context {
        model = llama_init.model.get();
        lctx = llama_init.context.get();
        vocab = llama_model_get_vocab(model);
        smpl = common_sampler_init(model, params.sampling);
        n_threads = params.cpuparams.n_threads;
        batch = llama_batch_init(params.n_batch, 0, 1);
        batch = llama_batch_init(1, 0, 1); // batch for next token generation
        n_batch = params.n_batch;

        if (!model || !lctx) {
@@ -118,6 +120,11 @@ struct mtmd_cli_context {
        }
    }

    ~mtmd_cli_context() {
        llama_batch_free(batch);
        common_sampler_free(smpl);
    }

    void init_vision_context(common_params & params) {
        const char * clip_path = params.mmproj.path.c_str();
        mtmd_context_params mparams = mtmd_context_params_default();
@@ -153,7 +160,7 @@ struct mtmd_cli_context {
    }
};

static int generate_response(mtmd_cli_context & ctx, common_sampler * smpl, int n_predict) {
static int generate_response(mtmd_cli_context & ctx, int n_predict) {
    llama_tokens generated_tokens;
    for (int i = 0; i < n_predict; i++) {
        if (i > n_predict || !g_is_generating || g_is_interrupted) {
@@ -161,9 +168,9 @@ static int generate_response(mtmd_cli_context & ctx, common_sampler * smpl, int
            break;
        }

        llama_token token_id = common_sampler_sample(smpl, ctx.lctx, -1);
        llama_token token_id = common_sampler_sample(ctx.smpl, ctx.lctx, -1);
        generated_tokens.push_back(token_id);
        common_sampler_accept(smpl, token_id, true);
        common_sampler_accept(ctx.smpl, token_id, true);

        if (llama_vocab_is_eog(ctx.vocab, token_id) || ctx.check_antiprompt(generated_tokens)) {
            LOG("\n");
@@ -261,7 +268,6 @@ int main(int argc, char ** argv) {

    bool is_single_turn = !params.prompt.empty() && !params.image.empty();

    struct common_sampler * smpl = common_sampler_init(ctx.model, params.sampling);
    int n_predict = params.n_predict < 0 ? INT_MAX : params.n_predict;

    // Ctrl+C handling
@@ -300,7 +306,7 @@ int main(int argc, char ** argv) {
        if (eval_message(ctx, msg, true)) {
            return 1;
        }
        if (!g_is_interrupted && generate_response(ctx, smpl, n_predict)) {
        if (!g_is_interrupted && generate_response(ctx, n_predict)) {
            return 1;
        }

@@ -366,7 +372,7 @@ int main(int argc, char ** argv) {
                return 1;
            }
            if (g_is_interrupted) break;
            if (generate_response(ctx, smpl, n_predict)) {
            if (generate_response(ctx, n_predict)) {
                return 1;
            }
            content.clear();

@@ -311,6 +311,7 @@ int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx,
        GGML_ABORT("chunk type not supported");
    }

    llama_batch_free(text_batch);
    return 0;
}

@@ -360,7 +360,7 @@ struct server_task {
                params.oaicompat_chat_syntax.format = defaults.oaicompat_chat_syntax.format;
            }
            params.oaicompat_chat_syntax.reasoning_format = params_base.reasoning_format;
            params.oaicompat_chat_syntax.reasoning_in_content = params.stream;
            params.oaicompat_chat_syntax.reasoning_in_content = params.stream && (params_base.reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY);
            params.oaicompat_chat_syntax.thinking_forced_open = json_value(data, "thinking_forced_open", false);
            params.oaicompat_chat_syntax.parse_tool_calls = json_value(data, "parse_tool_calls", false);
        }
@@ -2016,6 +2016,11 @@ struct server_context {
                params_base.n_cache_reuse = 0;
                SRV_WRN("%s\n", "cache_reuse is not supported by this context, it will be disabled");
            }

            if (!params_base.speculative.model.path.empty()) {
                SRV_ERR("%s\n", "err: speculative decode is not supported by this context");
                return false;
            }
        }

        return true;
@@ -3203,9 +3208,7 @@ struct server_context {
                    }
                } else {
                    // if we don't cache the prompt, we have to remove the entire KV cache
                    llama_kv_self_seq_rm(ctx, slot.id, 0, -1);
                    slot.n_past = 0;
                    slot.cache_tokens.clear(); // TODO: not needed, will be cleared later via "keep_first()"
                }

                if (slot.n_past > 0 && slot.n_past < (int) slot.cache_tokens.size()) {
@@ -3220,7 +3223,6 @@ struct server_context {
                        SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min, n_swa);
                        SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA, see %s)\n",
                                "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
                        llama_kv_self_seq_rm(ctx, slot.id, 0, -1);
                        slot.n_past = 0;
                    }
                }

@@ -499,13 +499,12 @@ def do_test_calc_result(server: ServerProcess, result_override: str | None, n_pr


@pytest.mark.slow
@pytest.mark.parametrize("n_predict,reasoning_format,stream,expect_reasoning_content,expect_content,hf_repo,template_override", [
    (128, 'deepseek', CompletionMode.NORMAL, None, "^The sum of 102 and 7 is 109[\\s\\S]*", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
    (128, None, CompletionMode.NORMAL, None, "^The sum of 102 and 7 is 109[\\s\\S]*", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
    (1024, 'deepseek', CompletionMode.NORMAL, "I need to calculate the sum of 102 and 7[\\s\\S]*", "To find the sum of[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
    (1024, 'deepseek', CompletionMode.STREAMED, None, "^<think>I need to calculate [\\s\\S]*?</think>To find the sum of [\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
    (1024, 'deepseek', CompletionMode.NORMAL, "First, I [\\s\\S]*", "To find the sum of[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
    (1024, 'deepseek', CompletionMode.STREAMED, None, "^<think>First, I [\\s\\S]*?</think>To find the sum of[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
@pytest.mark.parametrize("n_predict,reasoning_format,expect_reasoning_content,expect_content,hf_repo,template_override", [
    (128, 'deepseek', None, "^The sum of 102 and 7 is 109[\\s\\S]*", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
    (128, None, None, "^The sum of 102 and 7 is 109[\\s\\S]*", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
    (1024, 'deepseek', "I need to calculate the sum of 102 and 7[\\s\\S]*", "To find the sum of[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
    (1024, 'deepseek', "First, I [\\s\\S]*", "To find the sum of[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
    # (1024, 'none', CompletionMode.NORMAL, None, "^(<think>\\s*)?I need[\\s\\S]*?</think>\\s*To find[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
    # (128, 'deepseek', None, "^Okay, let me figure out the sum of 102 and 7[\\s\\S]*", "bartowski/Qwen_QwQ-32B-GGUF:Q4_K_M", None),
])

@@ -308,10 +308,12 @@ class ServerProcess:
        stream = data.get('stream', False)
        if stream:
            content: list[str] = []
            reasoning_content: list[str] = []
            tool_calls: list[dict] = []
            finish_reason: Optional[str] = None

            content_parts = 0
            reasoning_content_parts = 0
            tool_call_parts = 0
            arguments_parts = 0

@@ -322,6 +324,10 @@ class ServerProcess:
                    assert len(choice['delta']['content']) > 0, f'Expected non empty content delta!'
                    content.append(choice['delta']['content'])
                    content_parts += 1
                if choice['delta'].get('reasoning_content') is not None:
                    assert len(choice['delta']['reasoning_content']) > 0, f'Expected non empty reasoning_content delta!'
                    reasoning_content.append(choice['delta']['reasoning_content'])
                    reasoning_content_parts += 1
                if choice['delta'].get('finish_reason') is not None:
                    finish_reason = choice['delta']['finish_reason']
                for tc in choice['delta'].get('tool_calls', []):
@@ -349,8 +355,10 @@ class ServerProcess:
                        tool_call['function']['name'] = tool_call['function'].get('name', '') + fct['name']
                    if fct.get('arguments') is not None:
                        tool_call['function']['arguments'] += fct['arguments']
                        arguments_parts += 1
                    tool_call_parts += 1

            print(f'Streamed response had {content_parts} content parts, {tool_call_parts} tool call parts incl. {arguments_parts} arguments parts')
            print(f'Streamed response had {content_parts} content parts, {reasoning_content_parts} reasoning_content parts, {tool_call_parts} tool call parts incl. {arguments_parts} arguments parts')
            result = dict(
                choices=[
                    dict(
@@ -359,6 +367,7 @@ class ServerProcess:
                        message=dict(
                            role='assistant',
                            content=''.join(content) if content else None,
                            reasoning_content=''.join(reasoning_content) if reasoning_content else None,
                            tool_calls=tool_calls if tool_calls else None,
                        ),
                    )