Mirror of https://github.com/LostRuins/koboldcpp.git (synced 2025-09-11 01:24:36 +00:00)
Merge branch 'upstream' into concedo_experimental

# Conflicts:
#	ggml/src/ggml-rpc/ggml-rpc.cpp
#	ggml/src/ggml-sycl/common.hpp
#	ggml/src/ggml-sycl/ggml-sycl.cpp
#	tests/test-backend-ops.cpp
This commit is contained in:
commit 3f545eadbe
5 changed files with 55 additions and 39 deletions
clip.cpp

@@ -587,15 +587,15 @@ static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_im
 }
 
 // implementation of the 2D RoPE without adding a new op in ggml
+// this is not efficient (it uses double the memory), but works on all backends
+// TODO: there was a more efficient version which relied on ggml_view and ggml_rope_ext_inplace, but the inplace rope does not work well with non-contiguous tensors; we should fix that and revert back to the original implementation in https://github.com/ggml-org/llama.cpp/pull/13065
 static ggml_tensor * build_rope_2d(
-    ggml_cgraph * gf,
     ggml_context * ctx0,
     ggml_tensor * cur,
     ggml_tensor * pos_h,
     ggml_tensor * pos_w,
     const float freq_base
 ) {
-    ggml_tensor * tmp;
     const int64_t n_dim  = cur->ne[0];
     const int64_t n_head = cur->ne[1];
     const int64_t n_pos  = cur->ne[2];
@@ -604,18 +604,23 @@ static ggml_tensor * build_rope_2d(
     // we will have a list of 4 inv_freq: 1e-0, 1e-1, 1e-2, 1e-3
     // first half of cur will use 1e-0, 1e-2 (even)
     // second half of cur will use 1e-1, 1e-3 (odd)
-    //
-    // for the first half, the trick here is to rotate n_dim/2, so inv_freq will be even
+    // the trick here is to rotate just half of n_dim, so inv_freq will automatically be even
     //  ^ don't ask me why, it's math! -2(2i) / n_dim == -2i / (n_dim/2)
     // then for the second half, we use freq_scale to shift the inv_freq
     //  ^ why? replace (2i) with (2i+1) in the above equation
     const float freq_scale_odd = std::pow(freq_base, (float)-2/n_dim);
 
     // first half
+    ggml_tensor * first;
     {
-        cur = ggml_rope_ext_inplace(
+        first = ggml_view_3d(ctx0, cur,
+            n_dim/2, n_head, n_pos,
+            ggml_row_size(cur->type, n_dim),
+            ggml_row_size(cur->type, n_dim*n_head),
+            0);
+        first = ggml_rope_ext(
             ctx0,
-            cur,
+            first,
             pos_h,      // positions
             nullptr,    // freq factors
             n_dim/2,    // n_dims
@@ -625,15 +630,17 @@ static ggml_tensor * build_rope_2d(
     }
 
     // second half
+    ggml_tensor * second;
     {
-        tmp = ggml_view_3d(ctx0, cur,
+        second = ggml_view_3d(ctx0, cur,
             n_dim/2, n_head, n_pos,
             ggml_row_size(cur->type, n_dim),
             ggml_row_size(cur->type, n_dim*n_head),
             n_dim/2 * ggml_element_size(cur));
-        tmp = ggml_rope_ext_inplace(
+        second = ggml_cont(ctx0, second); // copy, because ggml_rope doesn't play well with non-contiguous tensors
+        second = ggml_rope_ext(
             ctx0,
-            tmp,
+            second,
             pos_w,      // positions
             nullptr,    // freq factors
             n_dim/2,    // n_dims
@@ -641,10 +648,9 @@ static ggml_tensor * build_rope_2d(
             freq_scale_odd,
             0.0f, 1.0f, 0.0f, 0.0f
         );
-        // calculate inplace (modify cur directly)
-        ggml_build_forward_expand(gf, tmp);
     }
 
+    cur = ggml_concat(ctx0, first, second, 0);
     return cur;
 }
 
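A note on the math in the comments above: rotating only the first n_dim/2 dims against pos_h picks up exactly the even-indexed inverse frequencies of a full-width RoPE, since freq_base^(-2(2i)/n_dim) == freq_base^(-2i/(n_dim/2)), and multiplying by freq_scale_odd = freq_base^(-2/n_dim) shifts those onto the odd-indexed frequencies for the pos_w half. Below is a standalone sketch that checks the identity numerically; it is not part of this commit, and n_dim = 8 / freq_base = 10000 are arbitrary example values:

    #include <cmath>
    #include <cstdio>

    int main() {
        const int   n_dim     = 8;        // example head dimension, not from the commit
        const float freq_base = 10000.0f; // typical RoPE theta, example value

        const float freq_scale_odd = std::pow(freq_base, -2.0f / n_dim);

        for (int i = 0; i < n_dim / 4; i++) {
            float even = std::pow(freq_base, -2.0f * (2*i    ) / n_dim); // full RoPE, even index
            float odd  = std::pow(freq_base, -2.0f * (2*i + 1) / n_dim); // full RoPE, odd index
            float half = std::pow(freq_base, -2.0f * i / (n_dim/2));     // half-width RoPE
            // expect: even == half, and odd == freq_scale_odd * half
            printf("i=%d  even=%g  half=%g  odd=%g  scaled=%g\n",
                   i, even, half, odd, freq_scale_odd * half);
        }
        return 0;
    }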
@@ -713,13 +719,13 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i
         struct ggml_tensor * Q = ggml_mul_mat(ctx0, model.layers[il].q_w, cur);
 
         Q = ggml_reshape_3d(ctx0, Q, d_head, n_head, num_patches);
-        Q = build_rope_2d(gf, ctx0, Q, pos_h, pos_w, hparams.rope_theta);
+        Q = build_rope_2d(ctx0, Q, pos_h, pos_w, hparams.rope_theta);
         Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
 
         struct ggml_tensor * K = ggml_mul_mat(ctx0, model.layers[il].k_w, cur);
 
         K = ggml_reshape_3d(ctx0, K, d_head, n_head, num_patches);
-        K = build_rope_2d(gf, ctx0, K, pos_h, pos_w, hparams.rope_theta);
+        K = build_rope_2d(ctx0, K, pos_h, pos_w, hparams.rope_theta);
         K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));
 
         struct ggml_tensor * V = ggml_mul_mat(ctx0, model.layers[il].v_w, cur);
@@ -3012,10 +3018,15 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
     const auto & model = ctx->vision_model;
     const auto & hparams = model.hparams;
 
+    // TODO @ngxson : this is ugly, need to refactor later
+    bool support_dynamic_size = ctx->has_minicpmv_projector
+                             || ctx->has_qwen2vl_merger
+                             || ctx->proj_type == PROJECTOR_TYPE_PIXTRAL;
+
     const int image_size = hparams.image_size;
     int image_size_width  = image_size;
     int image_size_height = image_size;
-    if (ctx->has_minicpmv_projector || ctx->has_qwen2vl_merger) {
+    if (support_dynamic_size) {
         image_size_width  = imgs.entries[0]->nx;
         image_size_height = imgs.entries[0]->ny;
     }
@@ -3027,9 +3038,20 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
 
     {
         struct ggml_tensor * inp_raw = ggml_graph_get_tensor(gf, "inp_raw");
-        float * data = (float *)malloc(ggml_nbytes(inp_raw));
+        std::vector<float> inp_data(ggml_nelements(inp_raw));
+        float * data = inp_data.data();
 
+        // layout of data (note: the channel dim is unrolled to better visualize the layout):
+        //
+        // ┌──W──┐
+        // │     H │  channel = R
+        // ├─────┤ │
+        // │     H │  channel = G
+        // ├─────┤ │
+        // │     H │  channel = B
+        // └─────┘ │
+        //  ──────┘ x B
 
+        // TODO @ngxson : this whole code block is ugly, will need to be refactored
         for (size_t i = 0; i < imgs.entries.size(); i++) {
             const int nx = imgs.entries[i]->nx;
             const int ny = imgs.entries[i]->ny;
@@ -3044,17 +3066,19 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
             const int n = nx * ny;
 
             for (int b = 0; b < batch_size; b++) {
-                for (int k = 0; k < 3; k++) {
+                float * batch_entry = data + b * (3*n);
                 for (int y = 0; y < ny; y++) {
                     for (int x = 0; x < nx; x++) {
-                        data[(b * 3 * n) + k * n + y * nx + x] = imgs.entries[b]->buf[3 * (y * nx + x) + k];
-                    }
+                        size_t base_src = 3*(y * nx + x); // idx of the first channel
+                        size_t base_dst =    y * nx + x;  // idx of the first channel
+                        batch_entry[      base_dst] = imgs.entries[b]->buf[base_src    ];
+                        batch_entry[1*n + base_dst] = imgs.entries[b]->buf[base_src + 1];
+                        batch_entry[2*n + base_dst] = imgs.entries[b]->buf[base_src + 2];
                     }
                 }
             }
         }
         ggml_backend_tensor_set(inp_raw, data, 0, ggml_nbytes(inp_raw));
-        free(data);
     }
 
     if (ctx->has_minicpmv_projector) {
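The rewritten inner loop converts each image from interleaved RGB (R,G,B per pixel, as the source buffers store it) to the planar layout drawn in the diagram above: the whole R plane, then G, then B, per batch entry, reading each pixel once instead of three times. A minimal standalone illustration of the same copy, using a hypothetical 2x2 image rather than code from the repo:

    #include <cstddef>
    #include <cstdio>
    #include <vector>

    int main() {
        const int nx = 2, ny = 2, n = nx * ny;   // hypothetical 2x2 image
        std::vector<float> buf = {               // interleaved source: R,G,B per pixel
            10,20,30,  11,21,31,
            12,22,32,  13,23,33,
        };
        std::vector<float> data(3 * n);          // planar destination
        float * batch_entry = data.data();       // batch of one image, b = 0

        for (int y = 0; y < ny; y++) {
            for (int x = 0; x < nx; x++) {
                size_t base_src = 3 * (y * nx + x); // idx of the pixel's first channel
                size_t base_dst =      y * nx + x;
                batch_entry[       base_dst] = buf[base_src    ]; // R plane
                batch_entry[1*n +  base_dst] = buf[base_src + 1]; // G plane
                batch_entry[2*n +  base_dst] = buf[base_src + 2]; // B plane
            }
        }
        for (float v : data) printf("%g ", v);   // 10 11 12 13 20 21 22 23 30 31 32 33
        printf("\n");
        return 0;
    }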
ggml-rpc.h

@@ -7,7 +7,7 @@
 extern "C" {
 #endif
 
-#define RPC_PROTO_MAJOR_VERSION    1
+#define RPC_PROTO_MAJOR_VERSION    2
 #define RPC_PROTO_MINOR_VERSION    0
 #define RPC_PROTO_PATCH_VERSION    0
 #define GGML_RPC_MAX_SERVERS       16
llama-context.cpp

@@ -469,8 +469,7 @@ ggml_tensor * llama_context::build_rope_shift(
         ggml_tensor * shift,
         ggml_tensor * factors,
               float   freq_base,
-              float   freq_scale,
-        ggml_backend_buffer * bbuf) const {
+              float   freq_scale) const {
     const auto & n_ctx_orig = cparams.n_ctx_orig_yarn;
 
     const auto & yarn_ext_factor = cparams.yarn_ext_factor;
@@ -492,17 +491,7 @@ ggml_tensor * llama_context::build_rope_shift(
         // dequantize to f32 -> RoPE -> quantize back
         tmp = ggml_cast(ctx0, cur, GGML_TYPE_F32);
 
-        if (bbuf) {
-            for (const auto & backend : backends) {
-                // Figure out which backend KV cache belongs to
-                if (ggml_backend_supports_buft(backend.get(), ggml_backend_buffer_get_type(bbuf))) {
-                    ggml_backend_sched_set_tensor_backend(sched.get(), tmp, backend.get());
-                    break;
-                }
-            }
-        }
-
-        tmp = ggml_rope_ext_inplace(ctx0, tmp,
+        tmp = ggml_rope_ext(ctx0, tmp,
                 shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                 yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
 
@@ -582,7 +571,7 @@ llm_graph_result_ptr llama_context::build_kv_self_shift(
             ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
             0);
 
-        ggml_tensor * cur = build_rope_shift(ctx0, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l, kv_self->k_l[il]->buffer);
+        ggml_tensor * cur = build_rope_shift(ctx0, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l);
 
         ggml_build_forward_expand(gf, cur);
     }
llama-context.h

@@ -170,8 +170,7 @@ private:
             ggml_tensor * shift,
             ggml_tensor * factors,
                   float   freq_base,
-                  float   freq_scale,
-            ggml_backend_buffer * bbuf) const;
+                  float   freq_scale) const;
 
     llm_graph_result_ptr build_kv_self_shift(
             ggml_context * ctx0,
llama-graph.cpp

@@ -803,6 +803,10 @@ ggml_tensor * llm_graph_context::build_ffn(
 
     if (down) {
         cur = build_lora_mm(down, cur);
+        if (arch == LLM_ARCH_GLM4) {
+            // GLM4 seems to have numerical issues with half-precision accumulators
+            ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
+        }
     }
 
     if (down_b) {
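For context on the added call: ggml_mul_mat_set_prec marks a single mul-mat node so the backend accumulates it in FP32 rather than the faster half-precision path. A minimal sketch of the call pattern, assuming a plain CPU ggml context; the tensor sizes are arbitrary and graph building/compute are omitted, so this is an illustration rather than code from the commit:

    #include "ggml.h"

    int main() {
        // small scratch context; 16 MiB is an arbitrary example size
        ggml_init_params params = { 16u*1024*1024, nullptr, false };
        ggml_context * ctx = ggml_init(params);

        ggml_tensor * w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, 64, 64); // f16 weights
        ggml_tensor * x = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64,  1); // f32 activations

        ggml_tensor * y = ggml_mul_mat(ctx, w, x);
        ggml_mul_mat_set_prec(y, GGML_PREC_F32); // request fp32 accumulation for this node

        ggml_free(ctx);
        return 0;
    }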