diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 0862099da..a47d7df6f 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -1778,6 +1778,12 @@ class LlamaModel(TextModel): model_arch = gguf.MODEL_ARCH.LLAMA undo_permute = True + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # fix for SmolVLM2, missing `num_attention_heads` in config.json + if self.hf_arch == "VLlama3ForCausalLM": + self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 32) + def set_vocab(self): try: self._set_vocab_sentencepiece() @@ -2123,6 +2129,9 @@ class DeciModel(TextModel): # if n_heads_in_group is not None, then # _num_kv_heads[il] is num_attention_head // n_heads_in_group and # _num_heads[il] is num_attention_head + # ***dummy layer*** for nemotron 253B + # if n_heads_in_group is None and ffn_mult is None + # then _num_kv_heads[il] is 0 and _num_heads[il] is 0 and _ffn_dims is 0 for il in range(len(_block_configs)): if _block_configs[il]["attention"]["n_heads_in_group"] is None: if _block_configs[il]["attention"]["replace_with_linear"] is True: @@ -2134,7 +2143,10 @@ else: self._num_kv_heads.append(self.hparams["num_attention_heads"] // _block_configs[il]["attention"]["n_heads_in_group"]) self._num_heads.append(self.hparams["num_attention_heads"]) - _ffn_multipliers.append(_block_configs[il]["ffn"]["ffn_mult"]) + if _block_configs[il]["ffn"]["ffn_mult"] is None: # dummy layer + _ffn_multipliers.append(0.0) + else: + _ffn_multipliers.append(_block_configs[il]["ffn"]["ffn_mult"]) assert self.block_count == len(self._num_kv_heads) assert self.block_count == len(self._num_heads) assert self.block_count == len(_ffn_multipliers) @@ -5674,7 +5686,12 @@ class BailingMoeModel(TextModel): rope_dim = hparams.get("head_dim") or hparams["hidden_size"] // hparams["num_attention_heads"] self.gguf_writer.add_rope_dimension_count(rope_dim) - self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE) + if (self.hparams.get("rope_scaling") or {}).get("type") == "yarn" and "factor" in self.hparams["rope_scaling"]: + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN) + self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"]) + self.gguf_writer.add_rope_scaling_orig_ctx_len(self.hparams["rope_scaling"]["original_max_position_embeddings"]) + else: + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE) self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"]) self.gguf_writer.add_vocab_size(hparams["vocab_size"]) self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"]) diff --git a/ggml/src/ggml-cpu/ggml-cpu-quants.c b/ggml/src/ggml-cpu/ggml-cpu-quants.c index 820c143bd..bb52be799 100644 --- a/ggml/src/ggml-cpu/ggml-cpu-quants.c +++ b/ggml/src/ggml-cpu/ggml-cpu-quants.c @@ -6597,7 +6597,118 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi } *s = hsum_float_8(acc); +#elif defined(__VXE__) || defined(__VXE2__) + uint32_t aux[3]; + uint32_t utmp[4]; + const int32x4_t v_z = vec_splat_s32(0); + const uint8x16_t v_3m = vec_splat_u8(0x03); + + const uint8x16_t v_0c = vec_splat_u8(1); + const uint8x16_t v_1c = vec_sl(v_0c, 1); + const uint8x16_t v_2c = vec_sl(v_0c, 2); + const uint8x16_t v_3c = vec_sl(v_0c, 3); + + uint8x16_t q3h[4]; + uint8x16_t q3b[2]; + int8x16_t q3bytes[4]; + int8x16_t q8bytes[8]; + uint8x16_t qhbits[2]; + + float sum = 0; + + for (int i = 0; i < nb; ++i) { + const float d = 
y[i].d * GGML_FP16_TO_FP32(x[i].d); + + const uint8_t * restrict x0l = x[i].qs; + const uint8_t * restrict x0h = x[i].hmask; + const int8_t * restrict y0 = y[i].qs; + + qhbits[0] = vec_xl(0 , x0h); + qhbits[1] = vec_xl(16, x0h); + + int32_t isum = 0; + + memcpy(aux, x[i].scales, 12); + utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4); + utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4); + utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4); + utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4); + + int8_t * scale = (int8_t *)utmp; + for (int j = 0; j < 16; ++j) scale[j] -= 32; + + for (int j = 0; j < QK_K/128; ++j) { + int32x4_t isum0, isum1, isum2, isum3; + + q3b[0] = vec_xl(0 , x0l); + q3b[1] = vec_xl(16, x0l); + x0l += 32; + + q8bytes[0] = vec_xl(0 , y0); + q8bytes[1] = vec_xl(16 , y0); + q8bytes[2] = vec_xl(32 , y0); + q8bytes[3] = vec_xl(48 , y0); + q8bytes[4] = vec_xl(64 , y0); + q8bytes[5] = vec_xl(80 , y0); + q8bytes[6] = vec_xl(96 , y0); + q8bytes[7] = vec_xl(112, y0); + y0 += 128; + + q3h[0] = vec_sl(vec_andc(v_0c, qhbits[0]), 2); + q3h[1] = vec_sl(vec_andc(v_0c, qhbits[1]), 2); + q3h[2] = vec_sl(vec_andc(v_1c, qhbits[0]), 1); + q3h[3] = vec_sl(vec_andc(v_1c, qhbits[1]), 1); + + q3bytes[0] = vec_sub((int8x16_t)vec_and(q3b[0], v_3m), (int8x16_t)q3h[0]); + q3bytes[1] = vec_sub((int8x16_t)vec_and(q3b[1], v_3m), (int8x16_t)q3h[1]); + q3bytes[2] = vec_sub((int8x16_t)vec_and(vec_sr(q3b[0], 2), v_3m), (int8x16_t)q3h[2]); + q3bytes[3] = vec_sub((int8x16_t)vec_and(vec_sr(q3b[1], 2), v_3m), (int8x16_t)q3h[3]); + + isum0 = ggml_vec_dot(v_z, q3bytes[0], q8bytes[0]); + isum1 = ggml_vec_dot(v_z, q3bytes[1], q8bytes[1]); + isum2 = ggml_vec_dot(v_z, q3bytes[2], q8bytes[2]); + isum3 = ggml_vec_dot(v_z, q3bytes[3], q8bytes[3]); + + isum += (isum0[0] + isum0[1] + isum0[2] + isum0[3]) * scale[0]; + isum += (isum1[0] + isum1[1] + isum1[2] + isum1[3]) * scale[1]; + isum += (isum2[0] + isum2[1] + isum2[2] + isum2[3]) * scale[2]; + isum += (isum3[0] + isum3[1] + isum3[2] + isum3[3]) * scale[3]; + + scale += 4; + + q3h[0] = vec_andc(v_2c, qhbits[0]); + q3h[1] = vec_andc(v_2c, qhbits[1]); + q3h[2] = vec_sr(vec_andc(v_3c, qhbits[0]), 1); + q3h[3] = vec_sr(vec_andc(v_3c, qhbits[1]), 1); + + q3bytes[0] = vec_sub((int8x16_t)vec_and(vec_sr(q3b[0], 4), v_3m), (int8x16_t)q3h[0]); + q3bytes[1] = vec_sub((int8x16_t)vec_and(vec_sr(q3b[1], 4), v_3m), (int8x16_t)q3h[1]); + q3bytes[2] = vec_sub((int8x16_t)vec_and(vec_sr(q3b[0], 6), v_3m), (int8x16_t)q3h[2]); + q3bytes[3] = vec_sub((int8x16_t)vec_and(vec_sr(q3b[1], 6), v_3m), (int8x16_t)q3h[3]); + + isum0 = ggml_vec_dot(v_z, q3bytes[0], q8bytes[4]); + isum1 = ggml_vec_dot(v_z, q3bytes[1], q8bytes[5]); + isum2 = ggml_vec_dot(v_z, q3bytes[2], q8bytes[6]); + isum3 = ggml_vec_dot(v_z, q3bytes[3], q8bytes[7]); + + isum += (isum0[0] + isum0[1] + isum0[2] + isum0[3]) * scale[0]; + isum += (isum1[0] + isum1[1] + isum1[2] + isum1[3]) * scale[1]; + isum += (isum2[0] + isum2[1] + isum2[2] + isum2[3]) * scale[2]; + isum += (isum3[0] + isum3[1] + isum3[2] + isum3[3]) * scale[3]; + + scale += 4; + + if (j == 0) { + qhbits[0] = vec_sr(qhbits[0], 4); + qhbits[1] = vec_sr(qhbits[1], 4); + } + } + + sum += d * isum; + } + + *s = sum; #else // scalar version // This function is written like this so the compiler can manage to vectorize most of it diff --git a/ggml/src/ggml-cpu/ggml-cpu.cpp b/ggml/src/ggml-cpu/ggml-cpu.cpp index 39bd935cb..a6d597c79 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.cpp +++ b/ggml/src/ggml-cpu/ggml-cpu.cpp @@ 
-11,24 +11,26 @@ #include <vector> #ifdef GGML_USE_CPU_HBM -#include "ggml-cpu-hbm.h" +# include "ggml-cpu-hbm.h" #endif #ifdef GGML_USE_CPU_KLEIDIAI -#include "kleidiai/kleidiai.h" -#endif - -#if defined(__APPLE__) -#include <sys/types.h> -#include <sys/sysctl.h> +# include "kleidiai/kleidiai.h" #endif #if defined(_WIN32) -#define WIN32_LEAN_AND_MEAN -#ifndef NOMINMAX - #define NOMINMAX +# define WIN32_LEAN_AND_MEAN +# ifndef NOMINMAX +# define NOMINMAX +# endif +# include <windows.h> +#else +# include <unistd.h> #endif -#include <windows.h> + +#if defined(__APPLE__) +# include <sys/types.h> +# include <sys/sysctl.h> #endif // ggml-backend interface @@ -70,8 +72,10 @@ static ggml_backend_buffer_type_t * ggml_backend_cpu_device_get_extra_buffers_ty } static bool ggml_backend_cpu_is_extra_buffer_type(ggml_backend_buffer_type_t buft) { - for (auto extra : ggml_backend_cpu_get_extra_buffers_type()) { - if (extra && extra == buft) return true; + for (auto * extra : ggml_backend_cpu_get_extra_buffers_type()) { + if (extra && extra == buft) { + return true; + } } return false; } @@ -330,9 +334,18 @@ static const char * ggml_backend_cpu_device_get_description(ggml_backend_dev_t d } static void ggml_backend_cpu_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { - // TODO - *free = 0; - *total = 0; +#ifdef _WIN32 + MEMORYSTATUSEX status; + status.dwLength = sizeof(status); + GlobalMemoryStatusEx(&status); + *total = status.ullTotalPhys; + *free = status.ullAvailPhys; +#else + long pages = sysconf(_SC_PHYS_PAGES); + long page_size = sysconf(_SC_PAGE_SIZE); + *total = pages * page_size; + *free = *total; +#endif GGML_UNUSED(dev); } diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 18383537e..0e11bc255 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -2637,6 +2637,7 @@ static __global__ void mul_mat_q( ids_dst_shared[j] = j; } + __syncthreads(); // On AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead: #if (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA @@ -2665,6 +2666,7 @@ static __global__ void mul_mat_q( return; } + // __syncthreads(); // There is no previous tile that could cause a race condition. 
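The `__syncthreads()` calls added throughout `mul_mat_q` all guard the same pattern: every thread in the block writes part of `ids_dst_shared`, other threads later read entries they did not write, and the buffer is rewritten for the next tile, so a thread that runs ahead can read stale ids or clobber ids a slower thread is still using. A minimal host-side sketch of that hazard, using C++20 std::barrier in place of `__syncthreads()` (thread count and buffer contents are illustrative, not taken from the kernel):

#include <barrier>
#include <cstdio>
#include <thread>
#include <vector>

int main() {
    constexpr int n_threads = 4;
    std::vector<int> shared(n_threads); // stand-in for ids_dst_shared
    std::barrier sync(n_threads);       // stand-in for __syncthreads()

    auto worker = [&](int tid) {
        for (int tile = 0; tile < 3; ++tile) {
            shared[tid] = tile * n_threads + tid;  // each thread fills its slot
            sync.arrive_and_wait();                // without this, a neighbor may read a stale slot
            int left = shared[(tid + n_threads - 1) % n_threads]; // read another thread's entry
            std::printf("tile %d: thread %d sees %d\n", tile, tid, left);
            sync.arrive_and_wait();                // keep next tile's writers from racing this tile's readers
        }
    };

    std::vector<std::thread> threads;
    for (int t = 0; t < n_threads; ++t) threads.emplace_back(worker, t);
    for (auto & t : threads) t.join();
    return 0;
}

The second `arrive_and_wait` mirrors why the patch places barriers both before and after the rewrites of `ids_dst_shared`, not just between a write and the following read.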
#pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += nwarps*WARP_SIZE) { const int j = j0 + threadIdx.y*WARP_SIZE + threadIdx.x; @@ -2675,6 +2677,7 @@ static __global__ void mul_mat_q( ids_dst_shared[j] = ids_dst[col_low + jt*mmq_x + j]; } + __syncthreads(); } offset_y += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int)); @@ -2741,6 +2744,7 @@ static __global__ void mul_mat_q( continue; } + __syncthreads(); #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += nwarps*WARP_SIZE) { const int j = j0 + threadIdx.y*WARP_SIZE + threadIdx.x; @@ -2751,6 +2755,7 @@ static __global__ void mul_mat_q( ids_dst_shared[j] = ids_dst[col_low + jt*mmq_x + j]; } + __syncthreads(); } offset_y += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int)); @@ -2806,6 +2811,7 @@ static __global__ void mul_mat_q( } // The memory layout for the fixup buffer is always contiguous, therefore reset ids: + __syncthreads(); #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += nwarps*WARP_SIZE) { const int j = j0 + threadIdx.y*WARP_SIZE + threadIdx.x; @@ -2816,6 +2822,7 @@ static __global__ void mul_mat_q( ids_dst_shared[j] = j; } + __syncthreads(); } offset_y += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int)); @@ -2952,6 +2959,7 @@ static __global__ void mul_mat_q_stream_k_fixup( for (int j = threadIdx.y*WARP_SIZE + threadIdx.x; j < mmq_x; j += nwarps*WARP_SIZE) { ids_dst_shared[j] = ids_dst[col_low + j]; } + __syncthreads(); const int offset_dst = it*mmq_y; dst += offset_dst; diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 947b0d9e1..8b64bbe0d 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -356,11 +356,17 @@ struct vk_device_struct { vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT]; vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT]; vk_pipeline pipeline_acc_f32; - vk_pipeline pipeline_add_f32, pipeline_add_f32_norepeat; - vk_pipeline pipeline_add_f16_f32_f16, pipeline_add_f16_f32_f16_norepeat; - vk_pipeline pipeline_sub_f32, pipeline_sub_f32_norepeat; - vk_pipeline pipeline_mul_f32, pipeline_mul_f32_norepeat; - vk_pipeline pipeline_div_f32, pipeline_div_f32_norepeat; + + // [src0 0=fp32,1=fp16][src1 0=fp32,1=fp16][dst 0=fp32,1=fp16] + vk_pipeline pipeline_add[2][2][2]; + vk_pipeline pipeline_add_norepeat[2][2][2]; + vk_pipeline pipeline_sub[2][2][2]; + vk_pipeline pipeline_sub_norepeat[2][2][2]; + vk_pipeline pipeline_mul[2][2][2]; + vk_pipeline pipeline_mul_norepeat[2][2][2]; + vk_pipeline pipeline_div[2][2][2]; + vk_pipeline pipeline_div_norepeat[2][2][2]; + vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32; vk_pipeline pipeline_upscale_f32; vk_pipeline pipeline_scale_f32; @@ -370,8 +376,8 @@ struct vk_device_struct { vk_pipeline pipeline_clamp_f32; vk_pipeline pipeline_pad_f32; vk_pipeline pipeline_repeat_f32, pipeline_repeat_back_f32; - vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16, pipeline_cpy_f32_bf16; - vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f32_bf16; + vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16, pipeline_cpy_f16_f32, pipeline_cpy_f32_bf16; + vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16; vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT]; vk_pipeline pipeline_cpy_quant_f32[GGML_TYPE_COUNT]; vk_pipeline pipeline_norm_f32; @@ -379,14 
+385,17 @@ struct vk_device_struct { vk_pipeline pipeline_rms_norm_f32; vk_pipeline pipeline_rms_norm_back_f32; vk_pipeline pipeline_l2_norm_f32; - vk_pipeline pipeline_gelu_f32; - vk_pipeline pipeline_gelu_quick_f32; - vk_pipeline pipeline_silu_f32; - vk_pipeline pipeline_silu_back_f32; - vk_pipeline pipeline_relu_f32; + + // [src/dst 0=fp32,1=fp16] + vk_pipeline pipeline_gelu[2]; + vk_pipeline pipeline_gelu_quick[2]; + vk_pipeline pipeline_silu[2]; + vk_pipeline pipeline_relu[2]; + vk_pipeline pipeline_tanh[2]; + vk_pipeline pipeline_sigmoid[2]; + vk_pipeline pipeline_leaky_relu_f32; - vk_pipeline pipeline_tanh_f32; - vk_pipeline pipeline_sigmoid_f32; + vk_pipeline pipeline_silu_back_f32; vk_pipeline pipeline_diag_mask_inf_f32; vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16; vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512; @@ -2524,11 +2533,13 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f16, "cpy_f16_f16", cpy_f16_f16_len, cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f32, "cpy_f16_f32", cpy_f16_f32_len, cpy_f16_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_bf16,"cpy_f32_bf16",cpy_f32_bf16_len,cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f32, "contig_cpy_f32_f32", contig_cpy_f32_f32_len, contig_cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f16, "contig_cpy_f32_f16", contig_cpy_f32_f16_len, contig_cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f16, "contig_cpy_f16_f16", contig_cpy_f16_f16_len, contig_cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f32, "contig_cpy_f16_f32", contig_cpy_f16_f32_len, contig_cpy_f16_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_bf16,"contig_cpy_f32_bf16",contig_cpy_f32_bf16_len,contig_cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); if (device->float_controls_rte_fp16) { @@ -2554,20 +2565,32 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q8_0], "cpy_q8_0_f32", cpy_q8_0_f32_len, cpy_q8_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_IQ4_NL], "cpy_iq4_nl_f32", cpy_iq4_nl_f32_len, cpy_iq4_nl_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, 
device->pipeline_add_f32, "add_f32", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1); - ggml_vk_create_pipeline(device, device->pipeline_add_f32_norepeat, "add_f32_norepeat", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1); - ggml_vk_create_pipeline(device, device->pipeline_add_f16_f32_f16, "add_f16_f32_f16", add_f16_f32_f16_len, add_f16_f32_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1); - ggml_vk_create_pipeline(device, device->pipeline_add_f16_f32_f16_norepeat, "add_f16_f32_f16_norepeat", add_f16_f32_f16_len, add_f16_f32_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1); + auto get_suffix = [](bool src0_f16, bool src1_f16, bool dst_f16) { + std::string s; + s += std::string(src0_f16 ? "_f16" : "_f32"); + s += std::string(src1_f16 ? "_f16" : "_f32"); + s += std::string(dst_f16 ? "_f16" : "_f32"); + return s; + }; + +#define CREATE_BINARY(name, namemod, spec) \ + for (int s0 : {0,1}) for (int s1 : {0,1}) for (int d : {0,1}) \ + ggml_vk_create_pipeline(device, device->pipeline_ ## name ## namemod[s0][s1][d], \ + #name + get_suffix(s0, s1, d) + #namemod, name ## _len[s0][s1][d], name ## _data[s0][s1][d], \ + "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, spec, 1); + + CREATE_BINARY(add, , {0}) + CREATE_BINARY(add, _norepeat, {1}) + CREATE_BINARY(sub, , {0}) + CREATE_BINARY(sub, _norepeat, {1}) + CREATE_BINARY(mul, , {0}) + CREATE_BINARY(mul, _norepeat, {1}) + CREATE_BINARY(div, , {0}) + CREATE_BINARY(div, _norepeat, {1}) +#undef CREATE_BINARY ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_sub_f32, "sub_f32", sub_f32_len, sub_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1); - ggml_vk_create_pipeline(device, device->pipeline_sub_f32_norepeat, "sub_f32_norepeat", sub_f32_len, sub_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1); - ggml_vk_create_pipeline(device, device->pipeline_mul_f32, "mul_f32", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1); - ggml_vk_create_pipeline(device, device->pipeline_mul_f32_norepeat, "mul_f32_norepeat", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1); - ggml_vk_create_pipeline(device, device->pipeline_div_f32, "div_f32", div_f32_len, div_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1); - ggml_vk_create_pipeline(device, device->pipeline_div_f32_norepeat, "div_f32_norepeat", div_f32_len, div_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1); - ggml_vk_create_pipeline(device, device->pipeline_concat_f32, "concat_f32", concat_f32_len, concat_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_concat_f16, "concat_f16", concat_f16_len, concat_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_concat_i32, "concat_i32", concat_i32_len, concat_i32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); @@ -2587,14 +2610,20 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_repeat_f32, "repeat_f32", 
repeat_f32_len, repeat_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_repeat_back_f32, "repeat_back_f32", repeat_back_f32_len, repeat_back_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_gelu_f32, "gelu_f32", gelu_f32_len, gelu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_gelu_quick_f32, "gelu_quick_f32", gelu_quick_f32_len, gelu_quick_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_silu_f32, "silu_f32", silu_f32_len, silu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_silu_back_f32, "silu_back_f32", silu_back_f32_len, silu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_relu_f32, "relu_f32", relu_f32_len, relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); +#define CREATE_UNARY(name) \ + ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \ + ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + + CREATE_UNARY(gelu) + CREATE_UNARY(gelu_quick) + CREATE_UNARY(silu) + CREATE_UNARY(relu) + CREATE_UNARY(tanh) + CREATE_UNARY(sigmoid) +#undef CREATE_UNARY + ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_tanh_f32, "tanh_f32", tanh_f32_len, tanh_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_sigmoid_f32, "sigmoid_f32", sigmoid_f32_len, sigmoid_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_silu_back_f32, "silu_back_f32", silu_back_f32_len, silu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {1, 512, 1}, {}, 1, true); @@ -4528,6 +4557,13 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_cpy_f16_f16; } } + if (src->type == GGML_TYPE_F16 && to == GGML_TYPE_F32) { + if (contig) { + return ctx->device->pipeline_contig_cpy_f16_f32; + } else { + return ctx->device->pipeline_cpy_f16_f32; + } + } if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_BF16) { if (contig) { return ctx->device->pipeline_contig_cpy_f32_bf16; @@ -5918,26 +5954,37 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const } return nullptr; case GGML_OP_ADD: - if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - return ggml_are_same_shape(src0, src1) ? 
ctx->device->pipeline_add_f32_norepeat : ctx->device->pipeline_add_f32; - } - if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) { - return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_f16_f32_f16_norepeat : ctx->device->pipeline_add_f16_f32_f16; - } - return nullptr; case GGML_OP_SUB: - if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_sub_f32_norepeat : ctx->device->pipeline_sub_f32; - } - return nullptr; case GGML_OP_MUL: - if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_mul_f32_norepeat : ctx->device->pipeline_mul_f32; - } - return nullptr; case GGML_OP_DIV: - if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_div_f32_norepeat : ctx->device->pipeline_div_f32; + if ((src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) || + (src1->type != GGML_TYPE_F32 && src1->type != GGML_TYPE_F16) || + (dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16)) { + return nullptr; + } + switch (op) { + case GGML_OP_ADD: + { + auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_norepeat : ctx->device->pipeline_add; + return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16]; + } + case GGML_OP_SUB: + { + auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_sub_norepeat : ctx->device->pipeline_sub; + return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16]; + } + case GGML_OP_MUL: + { + auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_mul_norepeat : ctx->device->pipeline_mul; + return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16]; + } + case GGML_OP_DIV: + { + auto pipelines = ggml_are_same_shape(src0, src1) ? 
ctx->device->pipeline_div_norepeat : ctx->device->pipeline_div; + return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16]; + } + default: + break; } return nullptr; case GGML_OP_CONCAT: @@ -6031,37 +6078,25 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const } return nullptr; case GGML_OP_UNARY: + if ((src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) || + (dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16) || + (src0->type != dst->type)) { + return nullptr; + } + switch (ggml_get_unary_op(dst)) { case GGML_UNARY_OP_SILU: - if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - return ctx->device->pipeline_silu_f32; - } - break; + return ctx->device->pipeline_silu[dst->type == GGML_TYPE_F16]; case GGML_UNARY_OP_GELU: - if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - return ctx->device->pipeline_gelu_f32; - } - break; + return ctx->device->pipeline_gelu[dst->type == GGML_TYPE_F16]; case GGML_UNARY_OP_GELU_QUICK: - if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - return ctx->device->pipeline_gelu_quick_f32; - } - break; + return ctx->device->pipeline_gelu_quick[dst->type == GGML_TYPE_F16]; case GGML_UNARY_OP_RELU: - if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - return ctx->device->pipeline_relu_f32; - } - break; + return ctx->device->pipeline_relu[dst->type == GGML_TYPE_F16]; case GGML_UNARY_OP_TANH: - if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - return ctx->device->pipeline_tanh_f32; - } - break; + return ctx->device->pipeline_tanh[dst->type == GGML_TYPE_F16]; case GGML_UNARY_OP_SIGMOID: - if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - return ctx->device->pipeline_sigmoid_f32; - } - break; + return ctx->device->pipeline_sigmoid[dst->type == GGML_TYPE_F16]; default: break; } @@ -9447,7 +9482,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_UNARY_OP_RELU: case GGML_UNARY_OP_TANH: case GGML_UNARY_OP_SIGMOID: - return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; + return ggml_is_contiguous(op->src[0]) && + (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && + (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && + (op->src[0]->type == op->type); default: return false; } @@ -9627,6 +9665,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm } if (src1_type == GGML_TYPE_F32) { switch (src0_type) { + case GGML_TYPE_F16: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -9665,6 +9704,9 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_SUB: case GGML_OP_MUL: case GGML_OP_DIV: + return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && + (op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16) && + (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16); case GGML_OP_SILU_BACK: case GGML_OP_RMS_NORM_BACK: case GGML_OP_SQR: diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp b/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp index 52a19b62a..4f806270c 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp @@ -17,5 +17,5 @@ void main() { return; } - data_d[i] = max(float(data_a[i]), 0); + data_d[i] = D_TYPE(max(float(data_a[i]), 0)); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp 
b/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp index 776581e2c..5c9e5c350 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp @@ -16,5 +16,5 @@ void main() { if (i >= p.KX) { return; } - data_d[i] = D_TYPE(1. / (1 + exp(-1. *data_a[i]))); + data_d[i] = D_TYPE(1. / (1 + exp(-1. * float(data_a[i])))); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp b/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp index 495f966bd..8a6f868f5 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp @@ -16,5 +16,5 @@ void main() { if (i >= p.KX) { return; } - data_d[i] = D_TYPE(1. - 2. / (exp(2.*data_a[i]) + 1.)); + data_d[i] = D_TYPE(1. - 2. / (exp(2.*float(data_a[i])) + 1.)); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 7a4cd5d47..851e38292 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -499,10 +499,12 @@ void process_shaders() { string_to_spv("cpy_f32_f32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("cpy_f32_f16", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}); string_to_spv("cpy_f16_f16", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); + string_to_spv("cpy_f16_f32", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); string_to_spv("cpy_f32_bf16","copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "uint16_t"}, {"DATA_D_BF16", "1"}}); string_to_spv("contig_cpy_f32_f32", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("contig_cpy_f32_f16", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}); string_to_spv("contig_cpy_f16_f16", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); + string_to_spv("contig_cpy_f16_f32", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); string_to_spv("contig_cpy_f32_bf16","contig_copy.comp",{{"A_TYPE", "float"}, {"D_TYPE", "uint16_t"}, {"DATA_D_BF16", "1"}}); for (std::string t : {"q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) { @@ -511,8 +513,26 @@ void process_shaders() { string_to_spv("cpy_" + t + "_f32", "copy_from_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); } - string_to_spv("add_f32", "add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); - string_to_spv("add_f16_f32_f16", "add.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}}); + auto get_type_str = [](bool f16) { + return f16 ? "float16_t" : "float"; + }; + auto get_suffix = [](bool src0_f16, bool src1_f16, bool dst_f16) { + std::string s; + s += std::string(src0_f16 ? "_f16" : "_f32"); + s += std::string(src1_f16 ? "_f16" : "_f32"); + s += std::string(dst_f16 ? 
"_f16" : "_f32"); + return s; + }; + for (std::string op : {"add", "sub", "mul", "div"}) { + for (auto src0_f16 : {false, true}) { + for (auto src1_f16 : {false, true}) { + for (auto dst_f16 : {false, true}) { + auto name = op + get_suffix(src0_f16, src1_f16, dst_f16); + string_to_spv(name.c_str(), op + ".comp", {{"A_TYPE", get_type_str(src0_f16)}, {"B_TYPE", get_type_str(src1_f16)}, {"D_TYPE", get_type_str(dst_f16)}, {"FLOAT_TYPE", "float"}}); + } + } + } + } string_to_spv("sub_f32", "sub.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); @@ -547,14 +567,21 @@ void process_shaders() { string_to_spv("upscale_f32", "upscale.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); - string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); - string_to_spv("gelu_quick_f32", "gelu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); - string_to_spv("silu_f32", "silu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); - string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); - string_to_spv("relu_f32", "relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); - string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); - string_to_spv("tanh_f32", "tanh.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); - string_to_spv("sigmoid_f32", "sigmoid.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("gelu_f16", "gelu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("gelu_quick_f16", "gelu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("gelu_quick_f32", "gelu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("silu_f16", "silu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("silu_f32", "silu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("relu_f16", "relu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("relu_f32", "relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("tanh_f16", "tanh.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("tanh_f32", "tanh.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("sigmoid_f16", "sigmoid.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("sigmoid_f32", "sigmoid.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + + string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("diag_mask_inf_f32", "diag_mask_inf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); @@ -655,7 +682,12 @@ void write_output_files() { std::remove(path.c_str()); } } - + for (const char *op : {"add", "sub", "mul", "div"}) { + fprintf(hdr, "extern unsigned char *%s_data[2][2][2];\n", op); + fprintf(hdr, "extern uint64_t %s_len[2][2][2];\n", op); + fprintf(src, "unsigned char *%s_data[2][2][2] = {{{%s_f32_f32_f32_data, %s_f32_f32_f16_data}, {%s_f32_f16_f32_data, %s_f32_f16_f16_data}}, {{%s_f16_f32_f32_data, %s_f16_f32_f16_data}, {%s_f16_f16_f32_data, %s_f16_f16_f16_data}}};\n", op, op, op, op, op, op, op, op, op); + fprintf(src, "uint64_t %s_len[2][2][2] = {{{%s_f32_f32_f32_len, 
%s_f32_f32_f16_len}, {%s_f32_f16_f32_len, %s_f32_f16_f16_len}}, {{%s_f16_f32_f32_len, %s_f16_f32_f16_len}, {%s_f16_f16_f32_len, %s_f16_f16_f16_len}}};\n", op, op, op, op, op, op, op, op, op); + } fclose(hdr); fclose(src); } diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 2b089f84a..003b0172c 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -977,15 +977,12 @@ class TensorNameMap: "visual.blocks.{bid}.norm2", # qwen2vl ), - # some namings are messed up because the original llava code swapped fc1 and fc2 - # we have no better way to fix it, just be careful - # new models like pixtral use the correct naming MODEL_TENSOR.V_ENC_FFN_UP: ( "vision_tower.vision_model.encoder.layers.{bid}.mlp.fc1", "vpm.encoder.layers.{bid}.mlp.fc1", - "model.vision_model.encoder.layers.{bid}.mlp.fc2", # SmolVLM, gemma3 (note: name is swapped) + "model.vision_model.encoder.layers.{bid}.mlp.fc1", # SmolVLM, gemma3 "vision_tower.transformer.layers.{bid}.feed_forward.up_proj", # pixtral - "visual.blocks.{bid}.mlp.fc2", # qwen2vl + "visual.blocks.{bid}.mlp.fc1", # qwen2vl "visual.blocks.{bid}.mlp.up_proj", # qwen2.5vl ), @@ -997,9 +994,9 @@ class TensorNameMap: MODEL_TENSOR.V_ENC_FFN_DOWN: ( "vision_tower.vision_model.encoder.layers.{bid}.mlp.fc2", "vpm.encoder.layers.{bid}.mlp.fc2", - "model.vision_model.encoder.layers.{bid}.mlp.fc1", # SmolVLM, gemma3 (note: name is swapped) + "model.vision_model.encoder.layers.{bid}.mlp.fc2", # SmolVLM, gemma3 "vision_tower.transformer.layers.{bid}.feed_forward.down_proj", # pixtral - "visual.blocks.{bid}.mlp.fc1", # qwen2vl + "visual.blocks.{bid}.mlp.fc2", # qwen2vl "visual.blocks.{bid}.mlp.down_proj", # qwen2.5vl ), diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 4c8b59740..74dfda8be 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -85,6 +85,7 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_236B: return "236B"; case LLM_TYPE_290B: return "290B"; case LLM_TYPE_314B: return "314B"; + case LLM_TYPE_405B: return "405B"; case LLM_TYPE_671B: return "671B"; case LLM_TYPE_SMALL: return "0.1B"; case LLM_TYPE_MEDIUM: return "0.4B"; @@ -587,6 +588,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { switch (hparams.n_layer) { case 32: type = LLM_TYPE_7B; break; case 80: type = LLM_TYPE_70B; break; + case 162: type = LLM_TYPE_405B; break; default: type = LLM_TYPE_UNKNOWN; } } break; @@ -1905,7 +1907,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + if (n_ff > 0) { + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + } if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) { layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); @@ -1915,9 +1919,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? 
TENSOR_DUPLICATED : 0)); } - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + if (n_ff > 0) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } // optional MLP bias layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); @@ -4808,6 +4814,7 @@ struct llm_build_deci : public llm_graph_context { ggml_tensor * inpSA = inpL; const int64_t n_head_kv = hparams.n_head_kv(il); const int64_t n_head = hparams.n_head(il); + const int64_t n_ff = hparams.n_ff(il); if (n_head == 0) { // attention-free layer of Llama-3_1-Nemotron-51B @@ -4883,6 +4890,11 @@ struct llm_build_deci : public llm_graph_context { inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } + // FFN-free layer of Llama-3_1-Nemotron-Ultra-253B + if (n_head == 0 && n_ff == 0) { + continue; + } + // For Granite architecture if (hparams.f_residual_scale) { cur = ggml_scale(ctx0, cur, hparams.f_residual_scale); diff --git a/src/llama-model.h b/src/llama-model.h index 4c7e7a335..815fa11eb 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -76,6 +76,7 @@ enum llm_type { LLM_TYPE_236B, LLM_TYPE_290B, LLM_TYPE_314B, + LLM_TYPE_405B, LLM_TYPE_671B, LLM_TYPE_SMALL, LLM_TYPE_MEDIUM, diff --git a/tests/test-mtmd-c-api.c b/tests/test-mtmd-c-api.c new file mode 100644 index 000000000..02e762e6a --- /dev/null +++ b/tests/test-mtmd-c-api.c @@ -0,0 +1,63 @@ +#include <assert.h> +#include <stdio.h> + +#include "mtmd.h" + +int main(void) { + printf("\n\nTesting libmtmd C API...\n"); + printf("--------\n\n"); + + struct mtmd_context_params params = mtmd_context_params_default(); + printf("Default image marker: %s\n", params.image_marker); + + mtmd_input_chunks * chunks = mtmd_test_create_input_chunks(); + + if (!chunks) { + fprintf(stderr, "Failed to create input chunks\n"); + return 1; + } + + size_t n_chunks = mtmd_input_chunks_size(chunks); + printf("Number of chunks: %zu\n", n_chunks); + assert(n_chunks > 0); + + for (size_t i = 0; i < n_chunks; i++) { + const mtmd_input_chunk * chunk = mtmd_input_chunks_get(chunks, i); + assert(chunk != NULL); + enum mtmd_input_chunk_type type = mtmd_input_chunk_get_type(chunk); + printf("Chunk %zu type: %d\n", i, type); + + if (type == MTMD_INPUT_CHUNK_TYPE_TEXT) { + size_t n_tokens; + const llama_token * tokens = mtmd_input_chunk_get_tokens_text(chunk, &n_tokens); + printf(" Text chunk with %zu tokens\n", n_tokens); + assert(tokens != NULL); + assert(n_tokens > 0); + for (size_t j = 0; j < n_tokens; j++) { + assert(tokens[j] >= 0); + printf(" > Token %zu: %d\n", j, tokens[j]); + } + + } else if (type == MTMD_INPUT_CHUNK_TYPE_IMAGE) { + const mtmd_image_tokens * image_tokens = mtmd_input_chunk_get_tokens_image(chunk); + size_t n_tokens = mtmd_image_tokens_get_n_tokens(image_tokens); + size_t nx = mtmd_image_tokens_get_nx(image_tokens); + size_t ny = mtmd_image_tokens_get_ny(image_tokens); + const char * id = mtmd_image_tokens_get_id(image_tokens); + assert(n_tokens > 0); + assert(nx > 0); + assert(ny > 0); + assert(id != NULL); + printf(" Image chunk with %zu tokens\n", n_tokens); + printf(" Image size: %zu x %zu\n", nx, ny); + printf(" Image ID: %s\n", 
id); + } + } + + // Free the chunks + mtmd_input_chunks_free(chunks); + + printf("\n\nDONE: test libmtmd C API...\n"); + + return 0; +} diff --git a/tools/CMakeLists.txt b/tools/CMakeLists.txt index 119d7c50a..06a96af55 100644 --- a/tools/CMakeLists.txt +++ b/tools/CMakeLists.txt @@ -27,13 +27,13 @@ else() add_subdirectory(run) add_subdirectory(tokenize) add_subdirectory(tts) + add_subdirectory(llava) + if (GGML_RPC) + add_subdirectory(rpc) + endif() if (NOT GGML_BACKEND_DL) # these examples use the backends directly and cannot be built with dynamic loading add_subdirectory(cvector-generator) add_subdirectory(export-lora) - add_subdirectory(llava) - if (GGML_RPC) - add_subdirectory(rpc) - endif() endif() endif() diff --git a/tools/llava/clip-impl.h b/tools/llava/clip-impl.h index b575ca4d7..fb780e9de 100644 --- a/tools/llava/clip-impl.h +++ b/tools/llava/clip-impl.h @@ -75,6 +75,8 @@ #define TN_MM_PROJECTOR "mm.model.fc.weight" // idefics3 #define TN_MM_PATCH_MERGER "mm.patch_merger.weight" // mistral small 3.1 #define TN_TOK_IMG_BREAK "v.token_embd.img_break" // pixtral +#define TN_TOK_GLM_BOI "adapter.boi" // glm-edge (these embeddings are not in text model) +#define TN_TOK_GLM_EOI "adapter.eoi" // glm-edge (these embeddings are not in text model) // minicpmv #define TN_MINICPMV_POS_EMBD_K "resampler.pos_embed_k" @@ -231,6 +233,15 @@ struct clip_image_u8_batch { struct clip_image_f32_batch { std::vector<clip_image_f32_ptr> entries; + + clip_image_f32_batch clone() const { + clip_image_f32_batch new_batch; + new_batch.entries.reserve(entries.size()); + for (const auto & entry : entries) { + new_batch.entries.emplace_back(new clip_image_f32(*entry)); + } + return new_batch; + } }; // diff --git a/tools/llava/clip.cpp b/tools/llava/clip.cpp index d2c66b6df..4f84ae5b2 100644 --- a/tools/llava/clip.cpp +++ b/tools/llava/clip.cpp @@ -169,8 +169,8 @@ enum patch_merge_type { struct clip_hparams { int32_t image_size; int32_t patch_size; - int32_t hidden_size; - int32_t n_intermediate; + int32_t n_embd; + int32_t n_ff; int32_t projection_dim; int32_t n_head; int32_t n_layer; @@ -205,12 +205,6 @@ struct clip_layer { struct ggml_tensor * ln_1_w = nullptr; struct ggml_tensor * ln_1_b = nullptr; - // ff - struct ggml_tensor * ff_i_w = nullptr; // legacy naming - struct ggml_tensor * ff_i_b = nullptr; // legacy naming - struct ggml_tensor * ff_o_w = nullptr; // legacy naming - struct ggml_tensor * ff_o_b = nullptr; // legacy naming - struct ggml_tensor * ff_up_w = nullptr; struct ggml_tensor * ff_up_b = nullptr; struct ggml_tensor * ff_gate_w = nullptr; @@ -218,9 +212,6 @@ struct clip_layer { struct ggml_tensor * ff_down_w = nullptr; struct ggml_tensor * ff_down_b = nullptr; - struct ggml_tensor * ff_g_w = NULL; - struct ggml_tensor * ff_g_b = NULL; - // layernorm 2 struct ggml_tensor * ln_2_w = nullptr; struct ggml_tensor * ln_2_b = nullptr; @@ -263,9 +254,11 @@ struct clip_vision_model { struct ggml_tensor * mm_4_w = nullptr; struct ggml_tensor * mm_4_b = nullptr; - //GLMV-Edge projection + // GLMV-Edge projection struct ggml_tensor * mm_model_adapter_conv_w = nullptr; struct ggml_tensor * mm_model_adapter_conv_b = nullptr; + struct ggml_tensor * mm_glm_tok_boi = nullptr; + struct ggml_tensor * mm_glm_tok_eoi = nullptr; // MobileVLM projection struct ggml_tensor * mm_model_mlp_1_w = nullptr; @@ -411,9 +404,9 @@ static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_im const int patch_size = hparams.patch_size; const int num_patches = ((image_size_width / patch_size) * (image_size_height / 
patch_size)); - const int hidden_size = hparams.hidden_size; + const int n_embd = hparams.n_embd; const int n_head = hparams.n_head; - const int d_head = hidden_size / n_head; + const int d_head = n_embd / n_head; const int n_layer = hparams.n_layer; const float eps = hparams.eps; @@ -434,7 +427,7 @@ static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_im ggml_set_input(inp_raw); struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1); - inp = ggml_reshape_2d(ctx0, inp, num_patches, hidden_size); + inp = ggml_reshape_2d(ctx0, inp, num_patches, n_embd); inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp)); inp = ggml_add(ctx0, inp, model.patch_bias); @@ -479,7 +472,7 @@ static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_im KQV = ggml_reshape_3d(ctx0, KQV, d_head, num_patches, n_head); KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3); - cur = ggml_cont_2d(ctx0, KQV, hidden_size, num_patches); + cur = ggml_cont_2d(ctx0, KQV, n_embd, num_patches); } // attention output @@ -496,14 +489,14 @@ static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_im cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_2_w), model.layers[il].ln_2_b); } - cur = ggml_mul_mat(ctx0, model.layers[il].ff_i_w, cur); - cur = ggml_add(ctx0, cur, model.layers[il].ff_i_b); + cur = ggml_mul_mat(ctx0, model.layers[il].ff_up_w, cur); + cur = ggml_add(ctx0, cur, model.layers[il].ff_up_b); // siglip uses gelu cur = ggml_gelu(ctx0, cur); - cur = ggml_mul_mat(ctx0, model.layers[il].ff_o_w, cur); - cur = ggml_add(ctx0, cur, model.layers[il].ff_o_b); + cur = ggml_mul_mat(ctx0, model.layers[il].ff_down_w, cur); + cur = ggml_add(ctx0, cur, model.layers[il].ff_down_b); // residual 2 cur = ggml_add(ctx0, embeddings, cur); @@ -527,11 +520,11 @@ static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_im const int kernel_size = patches_per_image / tokens_per_side; embeddings = ggml_cont(ctx0, ggml_transpose(ctx0, embeddings)); - embeddings = ggml_reshape_4d(ctx0, embeddings, patches_per_image, patches_per_image, hidden_size, batch_size); + embeddings = ggml_reshape_4d(ctx0, embeddings, patches_per_image, patches_per_image, n_embd, batch_size); // doing a pool2d to reduce the number of output tokens to 256 embeddings = ggml_pool_2d(ctx0, embeddings, GGML_OP_POOL_AVG, kernel_size, kernel_size, kernel_size, kernel_size, 0, 0); - embeddings = ggml_reshape_3d(ctx0, embeddings, embeddings->ne[0] * embeddings->ne[0], hidden_size, batch_size); + embeddings = ggml_reshape_3d(ctx0, embeddings, embeddings->ne[0] * embeddings->ne[0], n_embd, batch_size); embeddings = ggml_cont(ctx0, ggml_transpose(ctx0, embeddings)); // apply norm before projection @@ -660,9 +653,9 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i const int n_patches_x = image_size_width / patch_size; const int n_patches_y = image_size_height / patch_size; const int num_patches = n_patches_x * n_patches_y; - const int hidden_size = hparams.hidden_size; + const int n_embd = hparams.n_embd; const int n_head = hparams.n_head; - const int d_head = hidden_size / n_head; + const int d_head = n_embd / n_head; const int n_layer = hparams.n_layer; const float eps = hparams.eps; const int n_merge = hparams.spatial_merge_size; @@ -692,7 +685,7 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i ggml_set_input(pos_w); struct ggml_tensor * inp = ggml_conv_2d(ctx0, 
model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1); - inp = ggml_reshape_2d(ctx0, inp, num_patches, hidden_size); + inp = ggml_reshape_2d(ctx0, inp, num_patches, n_embd); inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp)); struct ggml_tensor * embeddings = inp; @@ -733,7 +726,7 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i KQV = ggml_reshape_3d(ctx0, KQV, d_head, num_patches, n_head); KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3); - cur = ggml_cont_2d(ctx0, KQV, hidden_size, num_patches); + cur = ggml_cont_2d(ctx0, KQV, n_embd, num_patches); cur = ggml_mul_mat(ctx0, model.layers[il].o_w, cur); } @@ -776,8 +769,8 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i cur = ggml_mul(ctx0, ggml_rms_norm(ctx0, cur, eps), model.mm_input_norm_w); // reshape image tokens to 2D grid - cur = ggml_reshape_3d(ctx0, cur, hidden_size, n_patches_x, n_patches_y); - cur = ggml_permute(ctx0, cur, 2, 0, 1, 3); // [x, y, hidden_size] + cur = ggml_reshape_3d(ctx0, cur, n_embd, n_patches_x, n_patches_y); + cur = ggml_permute(ctx0, cur, 2, 0, 1, 3); // [x, y, n_embd] cur = ggml_cont(ctx0, cur); // torch.nn.functional.unfold is just an im2col under the hood @@ -785,7 +778,7 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i ggml_tensor * kernel = ggml_view_3d(ctx0, cur, n_merge, n_merge, cur->ne[2], 0, 0, 0); cur = ggml_im2col(ctx0, kernel, cur, n_merge, n_merge, 0, 0, 1, 1, true, inp->type); - // project to hidden_size + // project to n_embd cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], cur->ne[1] * cur->ne[2]); cur = ggml_mul_mat(ctx0, model.mm_patch_merger_w, cur); embeddings = cur; @@ -808,9 +801,9 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i // arrangement of the [IMG_BREAK] token { // not efficient, but works - // the trick is to view the embeddings as a 3D tensor with shape [hidden_size, n_patches_per_row, n_rows] + // the trick is to view the embeddings as a 3D tensor with shape [n_embd, n_patches_per_row, n_rows] // and then concatenate the [IMG_BREAK] token to the end of each row, aka n_patches_per_row dimension - // after the concatenation, we have a tensor with shape [hidden_size, n_patches_per_row + 1, n_rows] + // after the concatenation, we have a tensor with shape [n_embd, n_patches_per_row + 1, n_rows] const int p_y = n_merge > 0 ? n_patches_y / n_merge : n_patches_y; const int p_x = n_merge > 0 ? n_patches_x / n_merge : n_patches_x; @@ -850,9 +843,9 @@ static ggml_cgraph * clip_image_build_graph_qwen25vl(clip_ctx * ctx, const clip_ const int patches_h = image_size_height / patch_size; const int num_positions = num_patches + (model.class_embedding ? 
1 : 0); const int num_position_ids = num_positions * 4; // m-rope requires 4 dim per position - const int hidden_size = hparams.hidden_size; + const int n_embd = hparams.n_embd; const int n_head = hparams.n_head; - const int d_head = hidden_size / n_head; + const int d_head = n_embd / n_head; const int n_layer = hparams.n_layer; const float eps = hparams.eps; @@ -887,14 +880,14 @@ static ggml_cgraph * clip_image_build_graph_qwen25vl(clip_ctx * ctx, const clip_ inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 2, 0, 3)); // [w, h, c, b] -> [c, w, h, b] inp = ggml_reshape_4d( ctx0, inp, - hidden_size * 2, patches_w / 2, patches_h, batch_size); + n_embd * 2, patches_w / 2, patches_h, batch_size); inp = ggml_reshape_4d( ctx0, inp, - hidden_size * 2, patches_w / 2, 2, batch_size * (patches_h / 2)); + n_embd * 2, patches_w / 2, 2, batch_size * (patches_h / 2)); inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 0, 2, 1, 3)); inp = ggml_reshape_3d( ctx0, inp, - hidden_size, patches_w * patches_h, batch_size); + n_embd, patches_w * patches_h, batch_size); if (model.patch_bias) { // inp = ggml_add(ctx0, inp, ggml_repeat(ctx0, model.patch_bias, inp)); @@ -927,11 +920,11 @@ static ggml_cgraph * clip_image_build_graph_qwen25vl(clip_ctx * ctx, const clip_ ggml_set_name(window_mask, "window_mask"); ggml_set_input(window_mask); - // embeddings shape: [hidden_size, patches_w * patches_h, batch_size] + // embeddings shape: [n_embd, patches_w * patches_h, batch_size] GGML_ASSERT(batch_size == 1); - embeddings = ggml_reshape_2d(ctx0, embeddings, hidden_size * 4, patches_w * patches_h * batch_size / 4); + embeddings = ggml_reshape_2d(ctx0, embeddings, n_embd * 4, patches_w * patches_h * batch_size / 4); embeddings = ggml_get_rows(ctx0, embeddings, inv_window_idx); - embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size, patches_w * patches_h, batch_size); + embeddings = ggml_reshape_3d(ctx0, embeddings, n_embd, patches_w * patches_h, batch_size); } // loop over layers @@ -984,7 +977,7 @@ static ggml_cgraph * clip_image_build_graph_qwen25vl(clip_ctx * ctx, const clip_ KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_positions, n_head, batch_size); KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3); - cur = ggml_cont_3d(ctx0, KQV, hidden_size, num_positions, batch_size); + cur = ggml_cont_3d(ctx0, KQV, n_embd, num_positions, batch_size); } // attention output @@ -1001,11 +994,11 @@ static ggml_cgraph * clip_image_build_graph_qwen25vl(clip_ctx * ctx, const clip_ // mlp // ffn_up - auto cur_up = ggml_mul_mat(ctx0, model.layers[il].ff_o_w, cur); - cur_up = ggml_add(ctx0, cur_up, model.layers[il].ff_o_b); + auto cur_up = ggml_mul_mat(ctx0, model.layers[il].ff_up_w, cur); + cur_up = ggml_add(ctx0, cur_up, model.layers[il].ff_up_b); - auto cur_gate = ggml_mul_mat(ctx0, model.layers[il].ff_g_w, cur); - cur_gate = ggml_add(ctx0, cur_gate, model.layers[il].ff_g_b); + auto cur_gate = ggml_mul_mat(ctx0, model.layers[il].ff_gate_w, cur); + cur_gate = ggml_add(ctx0, cur_gate, model.layers[il].ff_gate_b); // TODO : only 2 of these 3 are actually used, should we remove one of them? 
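With the rename, the qwen2.5vl MLP above reads as a standard gated FFN: ffn_down(act(ffn_gate(x)) * ffn_up(x)), where the activation is picked by the `use_gelu`/`use_silu` flags that follow. A self-contained scalar C++ sketch of that dataflow (toy dimensions, SiLU chosen arbitrarily as the activation, and the `ff_*_b` bias adds omitted):

#include <cmath>
#include <cstdio>
#include <vector>

using vec = std::vector<float>;
using mat = std::vector<vec>; // row-major: one row per output element

static vec matvec(const mat & w, const vec & x) {
    vec y(w.size(), 0.0f);
    for (size_t r = 0; r < w.size(); ++r)
        for (size_t c = 0; c < x.size(); ++c)
            y[r] += w[r][c] * x[c];
    return y;
}

static float silu(float v) { return v / (1.0f + std::exp(-v)); }

// out = W_down * (act(W_gate * x) * (W_up * x)) -- the shape of the MLP built above
static vec gated_ffn(const mat & w_up, const mat & w_gate, const mat & w_down, const vec & x) {
    vec up   = matvec(w_up, x);
    vec gate = matvec(w_gate, x);
    for (size_t i = 0; i < up.size(); ++i) up[i] *= silu(gate[i]); // elementwise gating
    return matvec(w_down, up);
}

int main() {
    // toy dims: n_embd = 2, n_ff = 3
    mat w_up   = {{1, 0}, {0, 1}, {1, 1}};
    mat w_gate = {{0.5f, 0}, {0, 0.5f}, {0.5f, 0.5f}};
    mat w_down = {{1, 1, 1}, {1, -1, 0}};
    vec y = gated_ffn(w_up, w_gate, w_down, {1.0f, 2.0f});
    std::printf("%f %f\n", y[0], y[1]);
    return 0;
}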
if (ctx->use_gelu) { cur_gate = ggml_gelu_inplace(ctx0, cur_gate); @@ -1017,8 +1010,8 @@ static ggml_cgraph * clip_image_build_graph_qwen25vl(clip_ctx * ctx, const clip_ cur = ggml_mul(ctx0, cur_gate, cur_up); // ffn_down - cur = ggml_mul_mat(ctx0, model.layers[il].ff_i_w, cur); - cur = ggml_add(ctx0, cur, model.layers[il].ff_i_b); + cur = ggml_mul_mat(ctx0, model.layers[il].ff_down_w, cur); + cur = ggml_add(ctx0, cur, model.layers[il].ff_down_b); // residual 2 cur = ggml_add(ctx0, embeddings, cur); @@ -1034,7 +1027,7 @@ static ggml_cgraph * clip_image_build_graph_qwen25vl(clip_ctx * ctx, const clip_ embeddings = ggml_mul(ctx0, embeddings, model.post_ln_w); } - embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size * 4, num_positions / 4, batch_size); + embeddings = ggml_reshape_3d(ctx0, embeddings, n_embd * 4, num_positions / 4, batch_size); embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings); embeddings = ggml_add(ctx0, embeddings, model.mm_0_b); @@ -1051,7 +1044,7 @@ static ggml_cgraph * clip_image_build_graph_qwen25vl(clip_ctx * ctx, const clip_ ggml_set_name(window_idx, "window_idx"); ggml_set_input(window_idx); - // embeddings shape: [hidden_size, patches_w * patches_h, batch_size] + // embeddings shape: [n_embd, patches_w * patches_h, batch_size] GGML_ASSERT(batch_size == 1); embeddings = ggml_reshape_2d(ctx0, embeddings, hparams.projection_dim, patches_w * patches_h / 4); embeddings = ggml_get_rows(ctx0, embeddings, window_idx); @@ -1097,9 +1090,9 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im const int patches_h = image_size_height / patch_size; const int num_positions = num_patches + (model.class_embedding ? 1 : 0); const int num_position_ids = ctx->proj_type == PROJECTOR_TYPE_QWEN2VL ? num_positions * 4 : num_positions; - const int hidden_size = hparams.hidden_size; + const int n_embd = hparams.n_embd; const int n_head = hparams.n_head; - const int d_head = hidden_size / n_head; + const int d_head = n_embd / n_head; const float eps = hparams.eps; int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4}; @@ -1137,17 +1130,17 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 2, 0, 3)); // [w, h, c, b] -> [c, w, h, b] inp = ggml_reshape_4d( ctx0, inp, - hidden_size * 2, patches_w / 2, patches_h, batch_size); + n_embd * 2, patches_w / 2, patches_h, batch_size); inp = ggml_reshape_4d( ctx0, inp, - hidden_size * 2, patches_w / 2, 2, batch_size * (patches_h / 2)); + n_embd * 2, patches_w / 2, 2, batch_size * (patches_h / 2)); inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 0, 2, 1, 3)); inp = ggml_reshape_3d( ctx0, inp, - hidden_size, patches_w * patches_h, batch_size); + n_embd, patches_w * patches_h, batch_size); } else { - inp = ggml_reshape_3d(ctx0, inp, num_patches, hidden_size, batch_size); + inp = ggml_reshape_3d(ctx0, inp, num_patches, n_embd, batch_size); inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 0, 2, 3)); } @@ -1160,7 +1153,7 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im // concat class_embeddings and patch_embeddings if (model.class_embedding) { - embeddings = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, num_positions, batch_size); + embeddings = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd, num_positions, batch_size); embeddings = ggml_scale(ctx0, embeddings, 0.0f); // set to all zeros embeddings = ggml_acc(ctx0, embeddings, model.class_embedding, embeddings->nb[1], embeddings->nb[2], 
embeddings->nb[3], 0); @@ -1257,7 +1250,7 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_positions, n_head, batch_size); KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3); - cur = ggml_cont_3d(ctx0, KQV, hidden_size, num_positions, batch_size); + cur = ggml_cont_3d(ctx0, KQV, n_embd, num_positions, batch_size); } // attention output @@ -1275,8 +1268,8 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_2_w), model.layers[il].ln_2_b); } - cur = ggml_mul_mat(ctx0, model.layers[il].ff_i_w, cur); - cur = ggml_add(ctx0, cur, model.layers[il].ff_i_b); + cur = ggml_mul_mat(ctx0, model.layers[il].ff_up_w, cur); + cur = ggml_add(ctx0, cur, model.layers[il].ff_up_b); if (ctx->use_gelu) { cur = ggml_gelu_inplace(ctx0, cur); @@ -1286,8 +1279,8 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im cur = ggml_gelu_quick_inplace(ctx0, cur); } - cur = ggml_mul_mat(ctx0, model.layers[il].ff_o_w, cur); - cur = ggml_add(ctx0, cur, model.layers[il].ff_o_b); + cur = ggml_mul_mat(ctx0, model.layers[il].ff_down_w, cur); + cur = ggml_add(ctx0, cur, model.layers[il].ff_down_b); // residual 2 cur = ggml_add(ctx0, embeddings, cur); @@ -1519,9 +1512,9 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im } { // attention - int hidden_size = clip_n_mmproj_embd(ctx); + int n_embd = clip_n_mmproj_embd(ctx); const int d_head = 128; - int n_head = hidden_size/d_head; + int n_head = n_embd/d_head; int num_query = 96; if (ctx->minicpmv_version == 2) { num_query = 96; @@ -1551,7 +1544,7 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ); KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_query, n_head, batch_size); KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3); - KQV = ggml_cont_3d(ctx0, KQV, hidden_size, num_query, batch_size); + KQV = ggml_cont_3d(ctx0, KQV, n_embd, num_query, batch_size); embeddings = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_o_w, KQV), model.mm_model_attn_o_b); } @@ -1584,10 +1577,17 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im embeddings = ggml_mul(ctx0, embeddings,x); embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_3_w, embeddings); } + // arrangement of BOI/EOI token embeddings + // note: these embeddings are not present in text model, hence we cannot process them as text tokens + // see: https://huggingface.co/THUDM/glm-edge-v-2b/blob/main/siglip.py#L53 + { + embeddings = ggml_concat(ctx0, model.mm_glm_tok_boi, embeddings, 1); // BOI + embeddings = ggml_concat(ctx0, embeddings, model.mm_glm_tok_eoi, 1); // EOI + } } else if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL) { - embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size * 4, num_positions / 4, batch_size); + embeddings = ggml_reshape_3d(ctx0, embeddings, n_embd * 4, num_positions / 4, batch_size); embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings); embeddings = ggml_add(ctx0, embeddings, model.mm_0_b); @@ -1727,9 +1727,9 @@ struct clip_model_loader { get_bool(KEY_USE_GELU, ctx_clip.use_gelu, false); get_bool(KEY_USE_SILU, ctx_clip.use_silu, false); - get_u32(KEY_N_EMBD, hparams.hidden_size); + get_u32(KEY_N_EMBD, hparams.n_embd); get_u32(KEY_N_HEAD, hparams.n_head); - get_u32(KEY_N_FF, hparams.n_intermediate); + get_u32(KEY_N_FF, hparams.n_ff); get_u32(KEY_N_BLOCK, hparams.n_layer); 
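The GLM-Edge change above brackets the projected image embeddings with the dedicated BOI/EOI token embeddings; the matching `n_patches += 2` in `clip_n_output_tokens()` further down keeps the reported token count in sync. A minimal sketch of the arrangement, assuming the shapes used in the graph above (embeddings laid out as [n_embd, n_tokens], special-token embeddings as [n_embd, 1]); the helper name is hypothetical:

```cpp
#include "ggml.h"

// Sketch only: concatenating along ggml dim 1 yields
//   [BOI][patch_0 .. patch_{n-1}][EOI]
// which is why the projector output has n_patches/4 + 2 tokens for GLM-Edge.
static ggml_tensor * glm_wrap_with_boi_eoi(ggml_context * ctx0,
                                           ggml_tensor * embeddings, // [n_embd, n_tokens]
                                           ggml_tensor * tok_boi,    // [n_embd, 1]
                                           ggml_tensor * tok_eoi) {  // [n_embd, 1]
    embeddings = ggml_concat(ctx0, tok_boi, embeddings, 1); // prepend BOI
    embeddings = ggml_concat(ctx0, embeddings, tok_eoi, 1); // append EOI
    return embeddings;
}
```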
get_u32(KEY_PROJ_DIM, hparams.projection_dim); get_f32(KEY_LAYER_NORM_EPS, hparams.eps); @@ -1848,6 +1848,7 @@ struct clip_model_loader { } void load_tensors() { + auto & hparams = ctx_clip.vision_model.hparams; std::map<std::string, size_t> tensor_offset; std::vector<ggml_tensor *> tensors_to_load; @@ -1901,8 +1902,8 @@ struct clip_model_loader { vision_model.position_embeddings = get_tensor(string_format(TN_POS_EMBD, "v"), false); // layers - vision_model.layers.resize(vision_model.hparams.n_layer); - for (int il = 0; il < vision_model.hparams.n_layer; ++il) { + vision_model.layers.resize(hparams.n_layer); + for (int il = 0; il < hparams.n_layer; ++il) { auto & layer = vision_model.layers[il]; layer.k_w = get_tensor(string_format(TN_ATTN_K, "v", il, "weight")); layer.q_w = get_tensor(string_format(TN_ATTN_Q, "v", il, "weight")); @@ -1925,13 +1926,18 @@ struct clip_model_loader { layer.ff_down_w = get_tensor(string_format(TN_FFN_DOWN, "v", il, "weight")); layer.ff_down_b = get_tensor(string_format(TN_FFN_DOWN, "v", il, "bias"), false); - // legacy naming (the in and out is reversed! don't ask me why) - layer.ff_i_w = layer.ff_down_w; - layer.ff_o_w = layer.ff_up_w; - layer.ff_g_w = layer.ff_gate_w; - layer.ff_i_b = layer.ff_down_b; - layer.ff_o_b = layer.ff_up_b; - layer.ff_g_b = layer.ff_gate_b; + // some models were exported with the legacy (incorrect) naming, which is quite messy; fix it here + // note: a Qwen model converted by the old surgery script has n_ff = 0, so we cannot use n_ff for this check! + if (layer.ff_up_w && layer.ff_down_w && layer.ff_down_w->ne[0] == hparams.n_embd) { + // swap up and down weights + ggml_tensor * tmp = layer.ff_up_w; + layer.ff_up_w = layer.ff_down_w; + layer.ff_down_w = tmp; + // swap up and down biases + tmp = layer.ff_up_b; + layer.ff_up_b = layer.ff_down_b; + layer.ff_down_b = tmp; + } } switch (ctx_clip.proj_type) { @@ -2022,12 +2028,14 @@ struct clip_model_loader { { vision_model.mm_model_adapter_conv_w = get_tensor(string_format(TN_GLM_ADAPER_CONV, "weight")); vision_model.mm_model_adapter_conv_b = get_tensor(string_format(TN_GLM_ADAPER_CONV, "bias")); - vision_model.mm_model_mlp_0_w = get_tensor(string_format(TN_GLM_ADAPTER_LINEAR,"weight")); - vision_model.mm_model_ln_q_w = get_tensor(string_format(TN_GLM_ADAPTER_NORM_1,"weight")); - vision_model.mm_model_ln_q_b = get_tensor(string_format(TN_GLM_ADAPTER_NORM_1,"bias")); - vision_model.mm_model_mlp_1_w = get_tensor(string_format(TN_GLM_ADAPTER_D_H_2_4H,"weight")); - vision_model.mm_model_mlp_2_w = get_tensor(string_format(TN_GLM_ADAPTER_GATE,"weight")); - vision_model.mm_model_mlp_3_w = get_tensor(string_format(TN_GLM_ADAPTER_D_4H_2_H,"weight")); + vision_model.mm_model_mlp_0_w = get_tensor(string_format(TN_GLM_ADAPTER_LINEAR, "weight")); + vision_model.mm_model_ln_q_w = get_tensor(string_format(TN_GLM_ADAPTER_NORM_1, "weight")); + vision_model.mm_model_ln_q_b = get_tensor(string_format(TN_GLM_ADAPTER_NORM_1, "bias")); + vision_model.mm_model_mlp_1_w = get_tensor(string_format(TN_GLM_ADAPTER_D_H_2_4H, "weight")); + vision_model.mm_model_mlp_2_w = get_tensor(string_format(TN_GLM_ADAPTER_GATE, "weight")); + vision_model.mm_model_mlp_3_w = get_tensor(string_format(TN_GLM_ADAPTER_D_4H_2_H, "weight")); + vision_model.mm_glm_tok_boi = get_tensor(string_format(TN_TOK_GLM_BOI, "weight")); + vision_model.mm_glm_tok_eoi = get_tensor(string_format(TN_TOK_GLM_EOI, "weight")); } break; case PROJECTOR_TYPE_QWEN2VL: case PROJECTOR_TYPE_QWEN25VL: @@ -3030,7 +3038,7 @@ int32_t clip_get_patch_size(const struct clip_ctx * ctx) { } int32_t
clip_get_hidden_size(const struct clip_ctx * ctx) { - return ctx->vision_model.hparams.hidden_size; + return ctx->vision_model.hparams.n_embd; } const char * clip_patch_merge_type(const struct clip_ctx * ctx) { @@ -3089,6 +3097,7 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im if (ctx->proj_type == PROJECTOR_TYPE_LDP || ctx->proj_type == PROJECTOR_TYPE_LDPV2 || ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE) { n_patches /= 4; + n_patches += 2; // for BOI and EOI token embeddings } else if (ctx->proj_type == PROJECTOR_TYPE_MINICPMV) { if (ctx->minicpmv_version == 2) { n_patches = 96; diff --git a/tools/llava/clip.h b/tools/llava/clip.h index 9c3311c7d..e4e12bb7a 100644 --- a/tools/llava/clip.h +++ b/tools/llava/clip.h @@ -78,10 +78,10 @@ CLIP_API int clip_uhd_num_image_embeds_col(struct clip_ctx * ctx_clip); CLIP_API void clip_add_load_image_size(struct clip_ctx * ctx_clip, struct clip_image_size * load_image_size); CLIP_API struct clip_image_size * clip_get_load_image_size(struct clip_ctx * ctx_clip); -CLIP_API struct clip_image_size * clip_image_size_init(); -CLIP_API struct clip_image_u8 * clip_image_u8_init (); -CLIP_API struct clip_image_f32 * clip_image_f32_init(); -CLIP_API struct clip_image_f32_batch * clip_image_f32_batch_init(); // only used by libllava +CLIP_API struct clip_image_size * clip_image_size_init(void); +CLIP_API struct clip_image_u8 * clip_image_u8_init (void); +CLIP_API struct clip_image_f32 * clip_image_f32_init(void); +CLIP_API struct clip_image_f32_batch * clip_image_f32_batch_init(void); // only used by libllava // nx, ny are the output image dimensions CLIP_API unsigned char * clip_image_u8_get_data(struct clip_image_u8 * img, uint32_t * nx, uint32_t * ny); diff --git a/tools/llava/llava.cpp b/tools/llava/llava.cpp index da5c0fd88..1b51cd673 100644 --- a/tools/llava/llava.cpp +++ b/tools/llava/llava.cpp @@ -2,6 +2,7 @@ #include "llava.h" #include "llama.h" +#include "ggml-cpp.h" #include <algorithm> #include <cerrno> @@ -209,7 +210,10 @@ static bool clip_llava_handle_patches(clip_ctx * ctx_clip, std::vector<float *> struct ggml_tensor *flatten = ggml_view_2d(model.ctx, permuted_cont, clip_n_mmproj_embd(ctx_clip), num_patches_height * num_patches_width * num_patches_per_side * num_patches_per_side, size_ele * clip_n_mmproj_embd(ctx_clip), 0); // ggml_tensor_printf(flatten,"flatten",__LINE__,false,false); ggml_build_forward_expand(gf, flatten); - ggml_graph_compute_with_ctx(model.ctx, gf, 1); + + ggml_backend_ptr backend { ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr) }; + ggml_backend_graph_compute(backend.get(), gf); + struct ggml_tensor* result = ggml_graph_node(gf, -1); memcpy(image_embd_out, image_embd_v[0], clip_embd_nbytes(ctx_clip)); // main image as global context diff --git a/tools/llava/mtmd-cli.cpp b/tools/llava/mtmd-cli.cpp index 474e7c4f8..4977d5480 100644 --- a/tools/llava/mtmd-cli.cpp +++ b/tools/llava/mtmd-cli.cpp @@ -63,7 +63,7 @@ static void sigint_handler(int signo) { #endif struct mtmd_cli_context { - mtmd_context_ptr ctx_vision; + mtmd::context_ptr ctx_vision; common_init_result llama_init; llama_model * model; @@ -72,7 +72,7 @@ struct mtmd_cli_context { llama_batch batch; int n_batch; - std::vector<mtmd_bitmap> bitmaps; + mtmd::bitmaps bitmaps; // note: we know that the gemma3 template is "linear", meaning each turn is completely separate from the others // so here we don't need to keep track of chat history @@ -92,6 +92,10 @@ struct mtmd_cli_context { batch = llama_batch_init(params.n_batch, 0, 1); n_batch = params.n_batch; + if (!model ||
!lctx) { + exit(1); + } + if (!llama_model_chat_template(model, nullptr) && params.chat_template.empty()) { LOG_ERR("Model does not have chat template.\n"); LOG_ERR(" For old llava models, you may need to use '--chat-template vicuna'\n"); @@ -115,12 +119,12 @@ struct mtmd_cli_context { void init_vision_context(common_params & params) { const char * clip_path = params.mmproj.path.c_str(); - ctx_vision.reset(mtmd_init_from_file(clip_path, model, mtmd_context_params{ - /* use_gpu */ params.mmproj_use_gpu, - /* timings */ true, - /* n_threads */ params.cpuparams.n_threads, - /* verbosity */ params.verbosity > 0 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO, - })); + mtmd_context_params mparams = mtmd_context_params_default(); + mparams.use_gpu = params.mmproj_use_gpu; + mparams.print_timings = true; + mparams.n_threads = params.cpuparams.n_threads; + mparams.verbosity = params.verbosity > 0 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO; + ctx_vision.reset(mtmd_init_from_file(clip_path, model, mparams)); if (!ctx_vision.get()) { LOG_ERR("Failed to load vision model from %s\n", clip_path); exit(1); @@ -139,11 +143,11 @@ struct mtmd_cli_context { } bool load_image(const std::string & fname) { - mtmd_bitmap bitmap; - if (mtmd_helper_bitmap_init_from_file(fname.c_str(), bitmap)) { + mtmd::bitmap bmp(mtmd_helper_bitmap_init_from_file(fname.c_str())); + if (!bmp.ptr) { return false; } - bitmaps.push_back(std::move(bitmap)); + bitmaps.entries.push_back(std::move(bmp)); return true; } }; @@ -193,27 +197,40 @@ static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg, bool add_ LOG_DBG("formatted_chat.prompt: %s\n", formatted_chat.prompt.c_str()); mtmd_input_text text; - text.text = formatted_chat.prompt; + text.text = formatted_chat.prompt.c_str(); text.add_special = add_bos; text.parse_special = true; - mtmd_input_chunks chunks; if (g_is_interrupted) return 0; - int32_t res = mtmd_tokenize(ctx.ctx_vision.get(), chunks, text, ctx.bitmaps); + mtmd::input_chunks chunks(mtmd_input_chunks_init()); + auto bitmaps_c_ptr = ctx.bitmaps.c_ptr(); + int32_t res = mtmd_tokenize(ctx.ctx_vision.get(), + chunks.ptr.get(), // output + &text, // text + bitmaps_c_ptr.data(), + bitmaps_c_ptr.size()); if (res != 0) { LOG_ERR("Unable to tokenize prompt, res = %d\n", res); return 1; } - ctx.bitmaps.clear(); + ctx.bitmaps.entries.clear(); - if (mtmd_helper_eval(ctx.ctx_vision.get(), ctx.lctx, chunks, ctx.n_past, 0, ctx.n_batch)) { + llama_pos new_n_past; + if (mtmd_helper_eval_chunks(ctx.ctx_vision.get(), + ctx.lctx, // lctx + chunks.ptr.get(), // chunks + ctx.n_past, // n_past + 0, // seq_id + ctx.n_batch, // n_batch + true, // logits_last + &new_n_past)) { LOG_ERR("Unable to eval prompt\n"); return 1; } - ctx.n_past += mtmd_helper_get_n_pos(chunks); + ctx.n_past = new_n_past; LOG("\n"); @@ -246,7 +263,7 @@ int main(int argc, char ** argv) { struct common_sampler * smpl = common_sampler_init(ctx.model, params.sampling); int n_predict = params.n_predict < 0 ? INT_MAX : params.n_predict; - // ctrl+C handling + // Ctrl+C handling { #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) struct sigaction sigint_action; diff --git a/tools/llava/mtmd.cpp b/tools/llava/mtmd.cpp index ac624fc4c..a2e6f91a3 100644 --- a/tools/llava/mtmd.cpp +++ b/tools/llava/mtmd.cpp @@ -12,6 +12,30 @@ #include #include +// represents raw image data, layout is RGBRGBRGB... 
+// length of data must be nx * ny * 3 +struct mtmd_bitmap { + uint32_t nx; + uint32_t ny; + std::vector<unsigned char> data; + std::string id; // optional user-defined id, e.g. the image hash, useful for KV cache tracking +}; + +struct mtmd_image_tokens_deleter { + void operator()(mtmd_image_tokens * val); // forward declaration +}; +using mtmd_image_tokens_ptr = std::unique_ptr<mtmd_image_tokens, mtmd_image_tokens_deleter>; + +struct mtmd_input_chunk { + mtmd_input_chunk_type type; + std::vector<llama_token> tokens_text; + mtmd_image_tokens_ptr tokens_image; +}; + +struct mtmd_input_chunks { + std::vector<mtmd_input_chunk> entries; +}; + // slice template, used by some llava-uhd models to correctly place the special tokens around image embeddings // models not having it (llava-1.6) will process embeddings without any special tokens in-between enum mtmd_slice_tmpl { @@ -21,6 +45,16 @@ enum mtmd_slice_tmpl { // TODO @ngxson : add support for idefics (SmolVLM) }; +mtmd_context_params mtmd_context_params_default() { + mtmd_context_params params; + params.use_gpu = true; + params.print_timings = true; + params.n_threads = 4; + params.verbosity = GGML_LOG_LEVEL_INFO; + params.image_marker = MTMD_DEFAULT_IMAGE_MARKER; + return params; +} + struct mtmd_context { struct clip_ctx * ctx_clip; const struct llama_model * text_model; @@ -132,6 +166,16 @@ struct mtmd_image_tokens { uint32_t n_tokens() const { return nx * ny; } clip_image_f32_batch batch_f32; // preprocessed image patches std::string id; // optional user-defined ID, useful for KV cache tracking + + mtmd_image_tokens clone() { + return mtmd_image_tokens{ + nx, + ny, + use_mrope_pos, + batch_f32.clone(), + id + }; + } }; mtmd_context * mtmd_init_from_file(const char * mmproj_fname, @@ -172,12 +216,13 @@ static std::vector<llama_token> mtmd_tokenize_text_internal( } int32_t mtmd_tokenize(mtmd_context * ctx, - std::vector<mtmd_input_chunk> & output, - const mtmd_input_text & text, - const std::vector<mtmd_bitmap> & bitmaps) { + mtmd_input_chunks * output, + const mtmd_input_text * text, + const mtmd_bitmap ** bitmaps, + size_t n_bitmaps) { auto vocab = llama_model_get_vocab(ctx->text_model); - std::string prompt_modified(text.text); + std::string prompt_modified(text->text); std::string marker_modified(ctx->image_marker); projector_type proj_type = clip_get_projector_type(ctx->ctx_clip); @@ -189,11 +234,6 @@ int32_t mtmd_tokenize(mtmd_context * ctx, marker_modified = "<start_of_image>" + ctx->image_marker + "<end_of_image>"; string_replace_all(prompt_modified, ctx->image_marker, marker_modified); - } else if (proj_type == PROJECTOR_TYPE_GLM_EDGE) { - // <|begin_of_image|> ... (image embeddings) ...
<|end_of_image|> - marker_modified = "<|begin_of_image|>" + ctx->image_marker + "<|end_of_image|>"; - string_replace_all(prompt_modified, ctx->image_marker, marker_modified); - } else if (proj_type == PROJECTOR_TYPE_IDEFICS3) { // https://github.com/huggingface/transformers/blob/a42ba80fa520c784c8f11a973ca9034e5f859b79/src/transformers/models/idefics3/processing_idefics3.py#L192-L215 marker_modified = "<fake_token_around_image><global-img>" + ctx->image_marker + "</fake_token_around_image>"; @@ -213,10 +253,11 @@ int32_t mtmd_tokenize(mtmd_context * ctx, } // llava-1.5, llava-1.6, Yi-VL, Yi-34B, granite: don't need to add prefix and suffix + // for glm-edge, the BOI and EOI token embeddings are not present in the text model std::vector<std::string> parts = string_split_str(prompt_modified, ctx->image_marker); - output.clear(); - output.reserve(parts.size()); + output->entries.clear(); + output->entries.reserve(parts.size()); size_t i_img = 0; @@ -227,7 +268,7 @@ int32_t mtmd_tokenize(mtmd_context * ctx, std::move(tokens), {}, }; - output.emplace_back(std::move(chunk)); + output->entries.emplace_back(std::move(chunk)); }; // utility for splitting a batch of multiple images into chunks, each holding a single image @@ -255,7 +296,7 @@ int32_t mtmd_tokenize(mtmd_context * ctx, for (const auto & part : parts) { // printf("tokenizing part: %s\n", part.c_str()); bool add_bos = &parts.front() == &part; - auto tokens = mtmd_tokenize_text_internal(vocab, part, text.add_special && add_bos, text.parse_special); + auto tokens = mtmd_tokenize_text_internal(vocab, part, text->add_special && add_bos, text->parse_special); if (tokens.empty()) { continue; } @@ -264,22 +305,22 @@ int32_t mtmd_tokenize(mtmd_context * ctx, std::move(tokens), {}, }; - output.emplace_back(std::move(chunk)); + output->entries.emplace_back(std::move(chunk)); if (&parts.back() != &part) { // add an image chunk between the two text parts - if (i_img >= bitmaps.size()) { + if (i_img >= n_bitmaps) { LOG_ERR("%s: error: not enough images for %d parts\n", __func__, (int)parts.size()); return 1; } // convert mtmd_bitmap to clip_image_u8 clip_image_u8_ptr img_u8(clip_image_u8_init()); - img_u8->nx = bitmaps[i_img].nx; - img_u8->ny = bitmaps[i_img].ny; - img_u8->buf.resize(bitmaps[i_img].data.size()); - std::memcpy(img_u8->buf.data(), bitmaps[i_img].data.data(), img_u8->nx * img_u8->ny * 3); + img_u8->nx = bitmaps[i_img]->nx; + img_u8->ny = bitmaps[i_img]->ny; + img_u8->buf.resize(bitmaps[i_img]->data.size()); + std::memcpy(img_u8->buf.data(), bitmaps[i_img]->data.data(), img_u8->nx * img_u8->ny * 3); clip_image_size img_u8_size{img_u8->nx, img_u8->ny}; // preprocess image @@ -292,12 +333,12 @@ int32_t mtmd_tokenize(mtmd_context * ctx, if (ctx->slice_tmpl == MTMD_SLICE_TMPL_MINICPMV_2_5 || ctx->slice_tmpl == MTMD_SLICE_TMPL_MINICPMV_2_6) { // split batch into chunks of single images - auto chunks = split_batch_to_chunk(std::move(batch_f32), bitmaps[i_img].id); + auto chunks = split_batch_to_chunk(std::move(batch_f32), bitmaps[i_img]->id); GGML_ASSERT(chunks.size() > 0); // add overview image add_text_chunk({ctx->tok_ov_img_start}); - output.emplace_back(std::move(chunks.front())); + output->entries.emplace_back(std::move(chunks.front())); chunks.erase(chunks.begin()); add_text_chunk({ctx->tok_ov_img_end}); @@ -315,7 +356,7 @@ int32_t mtmd_tokenize(mtmd_context * ctx, if (ctx->tok_sli_img_start != LLAMA_TOKEN_NULL) { add_text_chunk({ctx->tok_sli_img_start}); } - output.emplace_back(std::move(chunks[y * n_col + x])); + output->entries.emplace_back(std::move(chunks[y * n_col + x])); if (ctx->tok_sli_img_end != LLAMA_TOKEN_NULL) {
add_text_chunk({ctx->tok_sli_img_end}); } @@ -347,7 +388,7 @@ int32_t mtmd_tokenize(mtmd_context * ctx, image_tokens->ny = 1; } image_tokens->batch_f32 = std::move(batch_f32); - image_tokens->id = bitmaps[i_img].id; // optional + image_tokens->id = bitmaps[i_img]->id; // optional LOG_DBG("image_tokens->nx = %d\n", image_tokens->nx); LOG_DBG("image_tokens->ny = %d\n", image_tokens->ny); @@ -358,7 +399,7 @@ int32_t mtmd_tokenize(mtmd_context * ctx, {}, std::move(image_tokens), }; - output.emplace_back(std::move(chunk)); + output->entries.emplace_back(std::move(chunk)); } i_img++; // move to next image @@ -368,35 +409,12 @@ int32_t mtmd_tokenize(mtmd_context * ctx, return 0; } -void mtmd_image_tokens_free(mtmd_image_tokens * image_tokens) { +static void mtmd_image_tokens_free(mtmd_image_tokens * image_tokens) { if (image_tokens) { delete image_tokens; } } -size_t mtmd_image_tokens_get_n_tokens(const mtmd_image_tokens * image_tokens) { - return image_tokens->n_tokens(); -} - -size_t mtmd_image_tokens_get_nx(const mtmd_image_tokens * image_tokens) { - return image_tokens->nx; -} - -size_t mtmd_image_tokens_get_ny(const mtmd_image_tokens * image_tokens) { - return image_tokens->ny; -} - -std::string mtmd_image_tokens_get_id(const mtmd_image_tokens * image_tokens) { - return image_tokens->id; -} - -llama_pos mtmd_image_tokens_get_n_pos(const mtmd_image_tokens * image_tokens) { - if (image_tokens->use_mrope_pos) { - return 1; // for M-RoPE, the whole image is 1 in temporal dimension - } - return image_tokens->n_tokens(); -} - int32_t mtmd_encode(mtmd_context * ctx, const mtmd_image_tokens * image_tokens) { int n_mmproj_embd = clip_n_mmproj_embd(ctx->ctx_clip); ctx->image_embd_v.resize(image_tokens->n_tokens() * n_mmproj_embd); @@ -436,13 +454,18 @@ float * mtmd_get_output_embd(mtmd_context * ctx) { return ctx->image_embd_v.data(); } -size_t mtmd_helper_get_n_tokens(mtmd_input_chunks & chunks) { +size_t mtmd_helper_get_n_tokens(const mtmd_input_chunks * chunks) { size_t n_tokens = 0; - for (auto & chunk : chunks) { - if (chunk.type == MTMD_INPUT_CHUNK_TYPE_TEXT) { - n_tokens += chunk.tokens_text.size(); - } else if (chunk.type == MTMD_INPUT_CHUNK_TYPE_IMAGE) { - n_tokens += mtmd_image_tokens_get_n_tokens(chunk.tokens_image.get()); + for (size_t i = 0; i < mtmd_input_chunks_size(chunks); i++) { + auto chunk = mtmd_input_chunks_get(chunks, i); + auto chunk_type = mtmd_input_chunk_get_type(chunk); + if (chunk_type == MTMD_INPUT_CHUNK_TYPE_TEXT) { + size_t n_tokens_text; + mtmd_input_chunk_get_tokens_text(chunk, &n_tokens_text); + n_tokens += n_tokens_text; + } else if (chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE) { + auto tokens_image = mtmd_input_chunk_get_tokens_image(chunk); + n_tokens += mtmd_image_tokens_get_n_tokens(tokens_image); } else { GGML_ASSERT(false && "chunk type not supported"); } @@ -450,13 +473,18 @@ size_t mtmd_helper_get_n_tokens(mtmd_input_chunks & chunks) { return n_tokens; } -llama_pos mtmd_helper_get_n_pos(mtmd_input_chunks & chunks) { +llama_pos mtmd_helper_get_n_pos(const mtmd_input_chunks * chunks) { llama_pos n_pos = 0; - for (auto & chunk : chunks) { - if (chunk.type == MTMD_INPUT_CHUNK_TYPE_TEXT) { - n_pos += chunk.tokens_text.size(); - } else if (chunk.type == MTMD_INPUT_CHUNK_TYPE_IMAGE) { - n_pos += mtmd_image_tokens_get_n_pos(chunk.tokens_image.get()); + for (size_t i = 0; i < mtmd_input_chunks_size(chunks); i++) { + auto chunk = mtmd_input_chunks_get(chunks, i); + auto chunk_type = mtmd_input_chunk_get_type(chunk); + if (chunk_type == MTMD_INPUT_CHUNK_TYPE_TEXT) { + 
size_t n_tokens_text; + mtmd_input_chunk_get_tokens_text(chunk, &n_tokens_text); + n_pos += n_tokens_text; + } else if (chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE) { + auto tokens_image = mtmd_input_chunk_get_tokens_image(chunk); + n_pos += mtmd_image_tokens_get_n_pos(tokens_image); } else { GGML_ASSERT(false && "chunk type not supported"); } @@ -552,143 +580,172 @@ struct decode_embd_batch { } }; -int32_t mtmd_helper_eval(mtmd_context * ctx, - llama_context * lctx, - mtmd_input_chunks & chunks, - llama_pos pos0, +int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx, + struct llama_context * lctx, + const mtmd_input_chunk * chunk, + llama_pos n_past, llama_seq_id seq_id, - int32_t n_batch) { + int32_t n_batch, + bool logits_last, + llama_pos * new_n_past) { int32_t ret; - llama_pos n_past = pos0; llama_batch text_batch = llama_batch_init(n_batch, 0, 1); + auto chunk_type = mtmd_input_chunk_get_type(chunk); int n_mmproj_embd = clip_n_mmproj_embd(ctx->ctx_clip); int n_pos_per_embd = mtmd_decode_use_mrope(ctx) ? 4 : 1; - for (auto & chunk : chunks) { - bool is_last = &chunk == &chunks.back(); - if (chunk.type == MTMD_INPUT_CHUNK_TYPE_TEXT) { - text_batch.n_tokens = chunk.tokens_text.size(); - size_t i = 0; - while (i < chunk.tokens_text.size()) { // split into batches - for (; i < chunk.tokens_text.size() && text_batch.n_tokens < n_batch; i++) { - text_batch.token [i] = chunk.tokens_text[i]; - text_batch.pos [i] = n_past++; - text_batch.n_seq_id[i] = 1; - text_batch.seq_id [i][0] = seq_id; - text_batch.logits [i] = false; - } - if (is_last) { - // always get logits for last input chunk - text_batch.logits[text_batch.n_tokens - 1] = true; - } - ret = llama_decode(lctx, text_batch); - if (ret != 0) { - LOG_ERR("failed to decode text\n"); - llama_batch_free(text_batch); - return ret; - } + if (chunk_type == MTMD_INPUT_CHUNK_TYPE_TEXT) { + size_t n_tokens; + const auto tokens = mtmd_input_chunk_get_tokens_text(chunk, &n_tokens); + LOG_DBG("decoding text chunk, n_tokens = %zu\n", n_tokens); + size_t i = 0; + while (i < n_tokens) { // split into batches + text_batch.n_tokens = 0; // clear the batch + for (; i < n_tokens && text_batch.n_tokens < n_batch; i++) { + text_batch.n_tokens++; + text_batch.token [i] = tokens[i]; + text_batch.pos [i] = n_past++; + text_batch.n_seq_id[i] = 1; + text_batch.seq_id [i][0] = seq_id; + text_batch.logits [i] = false; } - - } else if (chunk.type == MTMD_INPUT_CHUNK_TYPE_IMAGE) { - GGML_ASSERT(!is_last && "logits for last image chunk is not yet supported"); - GGML_ASSERT(chunk.tokens_image != nullptr); - int64_t t0 = ggml_time_ms(); - if (ctx->print_timings) { - LOG_INF("encoding image or slice...\n"); + bool is_last_token = (i == n_tokens); + if (logits_last && is_last_token) { + text_batch.logits[text_batch.n_tokens - 1] = true; } - ret = mtmd_encode(ctx, chunk.tokens_image.get()); + ret = llama_decode(lctx, text_batch); if (ret != 0) { - LOG_ERR("failed to encode image\n"); + LOG_ERR("failed to decode text\n"); llama_batch_free(text_batch); return ret; } - if (ctx->print_timings) { - LOG_INF("image/slice encoded in %" PRId64 " ms\n", ggml_time_ms() - t0); - } - - int32_t n_tokens = mtmd_image_tokens_get_n_tokens(chunk.tokens_image.get()); - int32_t i_batch = 0; - int32_t n_img_batches = GGML_PAD(n_tokens, n_batch) / n_batch; - float * embd = mtmd_get_output_embd(ctx); - decode_embd_batch batch_embd(embd, n_tokens, n_pos_per_embd, n_mmproj_embd); - - const int nx = mtmd_image_tokens_get_nx(chunk.tokens_image.get()); - const int ny = 
mtmd_image_tokens_get_ny(chunk.tokens_image.get()); - - if (mtmd_decode_use_mrope(ctx)) { - batch_embd.set_position_mrope(n_past, nx, ny, seq_id); - } else { - batch_embd.set_position_normal(n_past, seq_id); - } - - if (mtmd_decode_use_non_causal(ctx)) { - llama_set_causal_attn(lctx, false); - // TODO @ngxson : need to make sure only one image is processed at a time, and n_ubatch must be enough to hold the image - } - - while (i_batch < n_img_batches) { // split into batches - int pos_offset = i_batch*n_batch; - int n_tokens_batch = std::min(n_batch, n_tokens - pos_offset); - llama_batch batch_embd_view = batch_embd.get_view(pos_offset, n_tokens_batch); - - LOG_INF("decoding image batch %d/%d, n_tokens_batch = %d\n", i_batch+1, n_img_batches, n_tokens_batch); - - int64_t t1 = ggml_time_ms(); - ret = llama_decode(lctx, batch_embd_view); - if (ret != 0) { - LOG_ERR("failed to decode image\n"); - llama_set_causal_attn(lctx, true); // restore causal attn - llama_batch_free(text_batch); - return ret; - } - - if (ctx->print_timings) { - LOG_INF("image decoded (batch %d/%d) in %" PRId64 " ms\n", i_batch+1, n_img_batches, ggml_time_ms() - t1); - } - - i_batch++; - } - - // for mrope, one image is one single **temporal** position - n_past += mtmd_decode_use_mrope(ctx) ? 1 : n_tokens; - - if (mtmd_decode_use_non_causal(ctx)) { - llama_set_causal_attn(lctx, true); - } - - } else { - GGML_ASSERT(false && "chunk type not supported"); + *new_n_past += text_batch.n_tokens; } + + } else if (chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE) { + const auto image_tokens = mtmd_input_chunk_get_tokens_image(chunk); + int64_t t0 = ggml_time_ms(); + if (ctx->print_timings) { + LOG_INF("encoding image or slice...\n"); + } + ret = mtmd_encode(ctx, image_tokens); + if (ret != 0) { + LOG_ERR("failed to encode image\n"); + llama_batch_free(text_batch); + return ret; + } + if (ctx->print_timings) { + LOG_INF("image/slice encoded in %" PRId64 " ms\n", ggml_time_ms() - t0); + } + + int32_t n_tokens = mtmd_image_tokens_get_n_tokens(image_tokens); + int32_t i_batch = 0; + int32_t n_img_batches = GGML_PAD(n_tokens, n_batch) / n_batch; + float * embd = mtmd_get_output_embd(ctx); + decode_embd_batch batch_embd(embd, n_tokens, n_pos_per_embd, n_mmproj_embd); + + const int nx = mtmd_image_tokens_get_nx(image_tokens); + const int ny = mtmd_image_tokens_get_ny(image_tokens); + + if (mtmd_decode_use_mrope(ctx)) { + batch_embd.set_position_mrope(n_past, nx, ny, seq_id); + } else { + batch_embd.set_position_normal(n_past, seq_id); + } + + if (mtmd_decode_use_non_causal(ctx)) { + llama_set_causal_attn(lctx, false); + // TODO @ngxson : need to make sure only one image is processed at a time, and n_ubatch must be enough to hold the image + } + + while (i_batch < n_img_batches) { // split into batches + int pos_offset = i_batch*n_batch; + int n_tokens_batch = std::min(n_batch, n_tokens - pos_offset); + llama_batch batch_embd_view = batch_embd.get_view(pos_offset, n_tokens_batch); + + LOG_INF("decoding image batch %d/%d, n_tokens_batch = %d\n", i_batch+1, n_img_batches, n_tokens_batch); + + int64_t t1 = ggml_time_ms(); + ret = llama_decode(lctx, batch_embd_view); + if (ret != 0) { + LOG_ERR("failed to decode image\n"); + llama_set_causal_attn(lctx, true); // restore causal attn + llama_batch_free(text_batch); + return ret; + } + + if (ctx->print_timings) { + LOG_INF("image decoded (batch %d/%d) in %" PRId64 " ms\n", i_batch+1, n_img_batches, ggml_time_ms() - t1); + } + + i_batch++; + } + + n_past += mtmd_image_tokens_get_n_pos(image_tokens); + 
*new_n_past = n_past; + + if (mtmd_decode_use_non_causal(ctx)) { + llama_set_causal_attn(lctx, true); + } + + } else { + GGML_ABORT("chunk type not supported"); } - llama_batch_free(text_batch); return 0; } -int32_t mtmd_helper_bitmap_init_from_buf(const unsigned char * buf, size_t len, mtmd_bitmap & output) { +int32_t mtmd_helper_eval_chunks(mtmd_context * ctx, + struct llama_context * lctx, + const mtmd_input_chunks * chunks, + llama_pos n_past, + llama_seq_id seq_id, + int32_t n_batch, + bool logits_last, + llama_pos * new_n_past) { + size_t n_chunks = mtmd_input_chunks_size(chunks); + if (n_chunks == 0) { + LOG_WRN("no chunks to eval\n"); + return 0; + } + + for (size_t i = 0; i < n_chunks; i++) { + bool chunk_logits_last = (i == n_chunks - 1) && logits_last; + auto chunk = mtmd_input_chunks_get(chunks, i); + + int32_t res = mtmd_helper_eval_chunk_single(ctx, lctx, chunk, n_past, seq_id, n_batch, chunk_logits_last, &n_past); + if (res != 0) { + LOG_ERR("failed to eval chunk %zu\n", i); + return res; + } + *new_n_past = n_past; + } + + return 0; +} + +mtmd_bitmap * mtmd_helper_bitmap_init_from_buf(const unsigned char * buf, size_t len) { clip_image_u8_ptr img_u8(clip_image_u8_init()); bool ok = clip_image_load_from_bytes(buf, len, img_u8.get(),2048); if (!ok) { LOG_ERR("Unable to load image from buffer\n"); - return 1; + return nullptr; } - unsigned char * data = clip_image_u8_get_data(img_u8.get(), &output.nx, &output.ny); - output.data.resize(output.nx * output.ny * 3); - std::memcpy(output.data.data(), data, output.nx * output.ny * 3); - return 0; + uint32_t nx, ny; + unsigned char * data = clip_image_u8_get_data(img_u8.get(), &nx, &ny); + return mtmd_bitmap_init(nx, ny, data); } -int32_t mtmd_helper_bitmap_init_from_file(const char * fname, mtmd_bitmap & output) { +mtmd_bitmap * mtmd_helper_bitmap_init_from_file(const char * fname) { clip_image_u8_ptr img_u8(clip_image_u8_init()); bool ok = clip_image_load_from_file(fname, img_u8.get()); if (!ok) { LOG_ERR("Unable to load image %s\n", fname); - return 1; + return nullptr; } - unsigned char * data = clip_image_u8_get_data(img_u8.get(), &output.nx, &output.ny); - output.data.resize(output.nx * output.ny * 3); - std::memcpy(output.data.data(), data, output.nx * output.ny * 3); - return 0; + uint32_t nx, ny; + unsigned char * data = clip_image_u8_get_data(img_u8.get(), &nx, &ny); + return mtmd_bitmap_init(nx, ny, data); } bool mtmd_decode_use_non_causal(mtmd_context * ctx) { @@ -706,3 +763,175 @@ bool mtmd_decode_use_mrope(mtmd_context * ctx) { void mtmd_image_tokens_deleter::operator()(mtmd_image_tokens * val) { mtmd_image_tokens_free(val); } + + +// +// public API functions +// + +// mtmd_bitmap + +mtmd_bitmap * mtmd_bitmap_init(uint32_t nx, + uint32_t ny, + const unsigned char * data) { + mtmd_bitmap * bitmap = new mtmd_bitmap; + bitmap->nx = nx; + bitmap->ny = ny; + size_t data_size = (size_t)nx * ny * 3; + bitmap->data.resize(data_size); + std::memcpy(bitmap->data.data(), data, data_size); + return bitmap; +} + +uint32_t mtmd_bitmap_get_nx(const mtmd_bitmap * bitmap) { + return bitmap->nx; +} + +uint32_t mtmd_bitmap_get_ny(const mtmd_bitmap * bitmap) { + return bitmap->ny; +} + +const unsigned char * mtmd_bitmap_get_data(const mtmd_bitmap * bitmap) { + return bitmap->data.data(); +} + +const char * mtmd_bitmap_get_id(const mtmd_bitmap * bitmap) { + return bitmap->id.c_str(); +} + +void mtmd_bitmap_set_id(mtmd_bitmap * bitmap, const char * id) { + if (id) { + bitmap->id = std::string(id); + } else { + bitmap->id.clear(); + } +} + 
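The setters and getters above complete the opaque-pointer bitmap API: the constructor copies the caller's pixel data, the id is optional metadata, and the object must be released with mtmd_bitmap_free(). A usage sketch under those assumptions; the `rgb` buffer and the "img-0" id are hypothetical:

```cpp
#include "mtmd.h"
#include <stdio.h>

// Sketch: construct a bitmap from caller-owned RGB data (nx * ny * 3 bytes,
// RGBRGBRGB... layout), tag it, inspect it, and free it.
static void bitmap_demo(const unsigned char * rgb, uint32_t nx, uint32_t ny) {
    mtmd_bitmap * bmp = mtmd_bitmap_init(nx, ny, rgb); // copies the pixel data
    mtmd_bitmap_set_id(bmp, "img-0");                  // optional, e.g. an image hash
    printf("bitmap %s: %ux%u\n",
           mtmd_bitmap_get_id(bmp),
           mtmd_bitmap_get_nx(bmp),
           mtmd_bitmap_get_ny(bmp));
    mtmd_bitmap_free(bmp); // frees the internal copy, not the caller's `rgb`
}
```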
+void mtmd_bitmap_free(mtmd_bitmap * bitmap) { + if (bitmap) { + delete bitmap; + } +} + +// mtmd_input_chunks + +mtmd_input_chunks * mtmd_input_chunks_init() { + return new mtmd_input_chunks; +} + +size_t mtmd_input_chunks_size(const mtmd_input_chunks * chunks) { + return chunks->entries.size(); +} + +const mtmd_input_chunk * mtmd_input_chunks_get(const mtmd_input_chunks * chunks, size_t idx) { + if (idx >= chunks->entries.size()) { + return nullptr; + } + return &chunks->entries[idx]; +} + +void mtmd_input_chunks_free(mtmd_input_chunks * chunks) { + if (chunks) { + delete chunks; + } +} + +// mtmd_input_chunk + +enum mtmd_input_chunk_type mtmd_input_chunk_get_type(const mtmd_input_chunk * chunk) { + return chunk->type; +} + +const llama_token * mtmd_input_chunk_get_tokens_text(const mtmd_input_chunk * chunk, size_t * n_tokens_output) { + if (chunk->type == MTMD_INPUT_CHUNK_TYPE_TEXT) { + *n_tokens_output = chunk->tokens_text.size(); + return chunk->tokens_text.data(); + } + *n_tokens_output = 0; + return nullptr; +} + +const mtmd_image_tokens * mtmd_input_chunk_get_tokens_image(const mtmd_input_chunk * chunk) { + if (chunk->type == MTMD_INPUT_CHUNK_TYPE_IMAGE) { + return chunk->tokens_image.get(); + } + return nullptr; +} + +mtmd_input_chunk * mtmd_input_chunk_copy(const mtmd_input_chunk * chunk) { + mtmd_input_chunk * copy = new mtmd_input_chunk{ + chunk->type, + chunk->tokens_text, + mtmd_image_tokens_ptr(), + }; + if (chunk->tokens_image) { + // copy the image tokens + copy->tokens_image = mtmd_image_tokens_ptr(new mtmd_image_tokens()); + *copy->tokens_image = chunk->tokens_image->clone(); + } + return copy; +} + +void mtmd_input_chunk_free(mtmd_input_chunk * chunk) { + if (chunk) { + delete chunk; + } +} + +// mtmd_image_tokens + +size_t mtmd_image_tokens_get_n_tokens(const mtmd_image_tokens * image_tokens) { + return image_tokens->n_tokens(); +} + +size_t mtmd_image_tokens_get_nx(const mtmd_image_tokens * image_tokens) { + return image_tokens->nx; +} + +size_t mtmd_image_tokens_get_ny(const mtmd_image_tokens * image_tokens) { + return image_tokens->ny; +} + +const char * mtmd_image_tokens_get_id(const mtmd_image_tokens * image_tokens) { + return image_tokens->id.c_str(); +} + +llama_pos mtmd_image_tokens_get_n_pos(const mtmd_image_tokens * image_tokens) { + if (image_tokens->use_mrope_pos) { + return 1; // for M-RoPE, the whole image is 1 in temporal dimension + } + return image_tokens->n_tokens(); +} + +// test function + +mtmd_input_chunks * mtmd_test_create_input_chunks() { + mtmd_input_chunks * chunks = mtmd_input_chunks_init(); + if (!chunks) { + return nullptr; + } + + // create a text chunk + std::vector<llama_token> tokens_text = { 1, 2, 3, 4, 5 }; + mtmd_input_chunk chunk_text{ + MTMD_INPUT_CHUNK_TYPE_TEXT, + std::move(tokens_text), + {}, + }; + chunks->entries.emplace_back(std::move(chunk_text)); + + // create an image chunk + mtmd_image_tokens_ptr image_tokens(new mtmd_image_tokens); + image_tokens->nx = 4; + image_tokens->ny = 4; + image_tokens->batch_f32.entries.resize(16); + image_tokens->id = "image_1"; + mtmd_input_chunk chunk_image{ + MTMD_INPUT_CHUNK_TYPE_IMAGE, + {}, + std::move(image_tokens), + }; + chunks->entries.emplace_back(std::move(chunk_image)); + + return chunks; +} diff --git a/tools/llava/mtmd.h b/tools/llava/mtmd.h index 6805e5e48..e2f76e2e8 100644 --- a/tools/llava/mtmd.h +++ b/tools/llava/mtmd.h @@ -5,9 +5,24 @@ #include "llama.h" #include "clip.h" +#include <stddef.h> +#include <stdint.h> +#include <stdbool.h> + +#ifdef __cplusplus #include <string> #include <vector> #include <memory> +#endif + +/** + * libmtmd: A library
for multimodal support in llama.cpp. + * + * WARNING: This API is experimental and subject to many BREAKING CHANGES. + * Issues related to API usage may receive lower priority support. + * + * For usage, see the example in mtmd-cli.cpp + */ #ifdef LLAMA_SHARED # if defined(_WIN32) && !defined(__MINGW32__) @@ -23,60 +38,118 @@ # define MTMD_API #endif +#define MTMD_DEFAULT_IMAGE_MARKER "<__image__>" + #ifdef __cplusplus +extern "C" { +#endif enum mtmd_input_chunk_type { MTMD_INPUT_CHUNK_TYPE_TEXT, MTMD_INPUT_CHUNK_TYPE_IMAGE, }; +// opaque types struct mtmd_context; +struct mtmd_bitmap; struct mtmd_image_tokens; - -// represents raw image data, layout is RGBRGBRGB... -// length of data must be nx * ny * 3 -struct mtmd_bitmap { - uint32_t nx; - uint32_t ny; - std::vector<unsigned char> data; - std::string id; // optional user-defined id, for ex: can be set to image hash, useful for KV cache tracking -}; - -struct mtmd_image_tokens_deleter { - void operator()(mtmd_image_tokens * val); // forward declaration -}; -using mtmd_image_tokens_ptr = std::unique_ptr<mtmd_image_tokens, mtmd_image_tokens_deleter>; - -struct mtmd_input_chunk { - mtmd_input_chunk_type type; - std::vector<llama_token> tokens_text; - mtmd_image_tokens_ptr tokens_image; -}; - -using mtmd_input_chunks = std::vector<mtmd_input_chunk>; - -struct mtmd_context_params { - bool use_gpu = true; - bool print_timings = true; - int n_threads = 4; - enum ggml_log_level verbosity = GGML_LOG_LEVEL_INFO; - const char * image_marker = "<__image__>"; -}; +struct mtmd_input_chunk; +struct mtmd_input_chunks; struct mtmd_input_text { - std::string text; + const char * text; bool add_special; bool parse_special; }; +// +// C API +// + +typedef struct mtmd_context mtmd_context; +typedef struct mtmd_bitmap mtmd_bitmap; +typedef struct mtmd_image_tokens mtmd_image_tokens; +typedef struct mtmd_input_chunk mtmd_input_chunk; +typedef struct mtmd_input_chunks mtmd_input_chunks; +typedef struct mtmd_input_text mtmd_input_text; + +struct mtmd_context_params { + bool use_gpu; + bool print_timings; + int n_threads; + enum ggml_log_level verbosity; + const char * image_marker; +}; + +MTMD_API struct mtmd_context_params mtmd_context_params_default(void); + // initialize the mtmd context // return nullptr on failure MTMD_API mtmd_context * mtmd_init_from_file(const char * mmproj_fname, - const llama_model * text_model, - const mtmd_context_params ctx_params); + const struct llama_model * text_model, + const struct mtmd_context_params ctx_params); MTMD_API void mtmd_free(mtmd_context * ctx); +// whether we need to set non-causal mask before llama_decode +MTMD_API bool mtmd_decode_use_non_causal(mtmd_context * ctx); + +// whether the current model uses M-RoPE for llama_decode +MTMD_API bool mtmd_decode_use_mrope(mtmd_context * ctx); + + +// mtmd_bitmap +// +// length of data must be nx * ny * 3 +// the data is in RGBRGBRGB...
format +MTMD_API mtmd_bitmap * mtmd_bitmap_init (uint32_t nx, + uint32_t ny, + const unsigned char * data); +MTMD_API uint32_t mtmd_bitmap_get_nx (const mtmd_bitmap * bitmap); +MTMD_API uint32_t mtmd_bitmap_get_ny (const mtmd_bitmap * bitmap); +MTMD_API const unsigned char * mtmd_bitmap_get_data(const mtmd_bitmap * bitmap); +MTMD_API void mtmd_bitmap_free (mtmd_bitmap * bitmap); +// the bitmap ID is optional, but useful for KV cache tracking +// these getters/setters are dedicated functions, so you can, for example, compute the image hash from mtmd_bitmap_get_data() and set it as the ID +MTMD_API const char * mtmd_bitmap_get_id(const mtmd_bitmap * bitmap); +MTMD_API void mtmd_bitmap_set_id(mtmd_bitmap * bitmap, const char * id); + + +// mtmd_input_chunks +// +// this is simply a list of mtmd_input_chunk +// the elements can only be populated via mtmd_tokenize() +MTMD_API mtmd_input_chunks * mtmd_input_chunks_init(void); +MTMD_API size_t mtmd_input_chunks_size(const mtmd_input_chunks * chunks); +MTMD_API const mtmd_input_chunk * mtmd_input_chunks_get (const mtmd_input_chunks * chunks, size_t idx); +MTMD_API void mtmd_input_chunks_free(mtmd_input_chunks * chunks); + +// mtmd_input_chunk +// +// the instance will be constructed via mtmd_tokenize() +// it will be freed along with mtmd_input_chunks +MTMD_API enum mtmd_input_chunk_type mtmd_input_chunk_get_type (const mtmd_input_chunk * chunk); +MTMD_API const llama_token * mtmd_input_chunk_get_tokens_text (const mtmd_input_chunk * chunk, size_t * n_tokens_output); +MTMD_API const mtmd_image_tokens * mtmd_input_chunk_get_tokens_image(const mtmd_input_chunk * chunk); + +// in case you want to use custom logic to handle the chunk (e.g. KV cache management) +// you can move the chunk ownership to your own code by copying it +// remember to free the chunk when you are done with it +MTMD_API mtmd_input_chunk * mtmd_input_chunk_copy(const mtmd_input_chunk * chunk); +MTMD_API void mtmd_input_chunk_free(mtmd_input_chunk * chunk); + + +// mtmd_image_tokens +// +// the instance will be constructed via mtmd_tokenize() +// it will be freed along with mtmd_input_chunk +MTMD_API size_t mtmd_image_tokens_get_n_tokens(const mtmd_image_tokens * image_tokens); +MTMD_API size_t mtmd_image_tokens_get_nx (const mtmd_image_tokens * image_tokens); +MTMD_API size_t mtmd_image_tokens_get_ny (const mtmd_image_tokens * image_tokens); +MTMD_API const char * mtmd_image_tokens_get_id (const mtmd_image_tokens * image_tokens); +// number of temporal positions (always 1 for M-RoPE, n_tokens otherwise) +MTMD_API llama_pos mtmd_image_tokens_get_n_pos (const mtmd_image_tokens * image_tokens); + // tokenize an input text prompt and an image // the prompt must have the input image marker (default: "<__image__>") in it // the marker will be replaced with the image tokens @@ -93,75 +166,152 @@ MTMD_API void mtmd_free(mtmd_context * ctx); // 1 on number of images not matching the number of markers // 2 on image preprocessing error MTMD_API int32_t mtmd_tokenize(mtmd_context * ctx, - std::vector<mtmd_input_chunk> & output, - const mtmd_input_text & text, - const std::vector<mtmd_bitmap> & bitmaps); - -// access mtmd_image_tokens -MTMD_API size_t mtmd_image_tokens_get_n_tokens(const mtmd_image_tokens * image_tokens); -MTMD_API size_t mtmd_image_tokens_get_nx(const mtmd_image_tokens * image_tokens); -MTMD_API size_t mtmd_image_tokens_get_ny(const mtmd_image_tokens * image_tokens); -MTMD_API std::string mtmd_image_tokens_get_id(const mtmd_image_tokens * image_tokens); -MTMD_API llama_pos mtmd_image_tokens_get_n_pos(const
mtmd_image_tokens * image_tokens); // number of temporal positions (always 1 for M-RoPE, n_tokens otherwise) -MTMD_API void mtmd_image_tokens_free(mtmd_image_tokens * image_tokens); + mtmd_input_chunks * output, + const mtmd_input_text * text, + const mtmd_bitmap ** bitmaps, + size_t n_bitmaps); // returns 0 on success MTMD_API int32_t mtmd_encode(mtmd_context * ctx, - const mtmd_image_tokens * image_tokens); + const mtmd_image_tokens * image_tokens); // get output embeddings from the last encode pass MTMD_API float * mtmd_get_output_embd(mtmd_context * ctx); -// whether we need to set non-causal mask before llama_decode -MTMD_API bool mtmd_decode_use_non_causal(mtmd_context * ctx); - -// whether the current model use M-RoPE for llama_decode -MTMD_API bool mtmd_decode_use_mrope(mtmd_context * ctx); - - +///////////////////////////////////////// // -// helper functions (can be implemented based on other functions) +// Helper functions (can be implemented based on other functions) // +// Please note that these helpers are not guaranteed to be stable. +// BREAKING CHANGES are expected. +// + +// helper function to construct a mtmd_bitmap from a file +// returns nullptr on failure +// this function is thread-safe +MTMD_API mtmd_bitmap * mtmd_helper_bitmap_init_from_file(const char * fname); + +// helper function to construct a mtmd_bitmap from a buffer containing a file +// the file content must be an image in a format supported by stb_image (jpg, png, bmp, gif, etc.) +// returns nullptr on failure +// this function is thread-safe +MTMD_API mtmd_bitmap * mtmd_helper_bitmap_init_from_buf(const unsigned char * buf, size_t len); // helper to count the total number of tokens from a list of chunks, useful to keep track of KV cache -MTMD_API size_t mtmd_helper_get_n_tokens(mtmd_input_chunks & chunks); +MTMD_API size_t mtmd_helper_get_n_tokens(const mtmd_input_chunks * chunks); // helper to count the total number of positions from a list of chunks, useful to keep track of n_past +// normally, n_pos is equal to n_tokens, but for M-RoPE it is different -MTMD_API llama_pos mtmd_helper_get_n_pos(mtmd_input_chunks & chunks); +MTMD_API llama_pos mtmd_helper_get_n_pos(const mtmd_input_chunks * chunks); // helper function that automatically: // 1. run llama_decode() on text chunks // 2.
run mtmd_encode() on image chunks, then mtmd_get_output_embd() and then llama_decode() // if any of the mtmd_encode() or llama_decode() calls return non-zero, stop and forward the error // otherwise, returns 0 on success -MTMD_API int32_t mtmd_helper_eval(mtmd_context * ctx, - llama_context * lctx, - mtmd_input_chunks & chunks, - llama_pos pos0, - llama_seq_id seq_id, - int32_t n_batch); +// this function is NOT thread-safe +MTMD_API int32_t mtmd_helper_eval_chunks(mtmd_context * ctx, + struct llama_context * lctx, + const mtmd_input_chunks * chunks, + llama_pos n_past, + llama_seq_id seq_id, + int32_t n_batch, + bool logits_last, + llama_pos * new_n_past); -// helper function to construct a mtmd_bitmap from a file -// returns 0 on success -// this function is thread-safe -MTMD_API int32_t mtmd_helper_bitmap_init_from_file(const char * fname, mtmd_bitmap & output); +// works like mtmd_helper_eval_chunks(), but only for a single chunk +// this function is NOT thread-safe +MTMD_API int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx, + struct llama_context * lctx, + const mtmd_input_chunk * chunk, + llama_pos n_past, + llama_seq_id seq_id, + int32_t n_batch, + bool logits_last, + llama_pos * new_n_past); -// helper function to construct a mtmd_bitmap from a buffer -// the buffer must be an image in a format supported by stb_image (jpg, png, bmp, gif, etc.) -// returns 0 on success -// this function is thread-safe -MTMD_API int32_t mtmd_helper_bitmap_init_from_buf(const unsigned char * buf, size_t len, mtmd_bitmap & output); +///////////////////////////////////////// + +// test function, to be used in test-mtmd-c-api.c +MTMD_API mtmd_input_chunks * mtmd_test_create_input_chunks(void); + +#ifdef __cplusplus +} // extern "C" +#endif + +// +// C++ wrappers +// + +#ifdef __cplusplus + +namespace mtmd { -// convenient unique_ptr wrappers struct mtmd_context_deleter { void operator()(mtmd_context * val) { mtmd_free(val); } }; -using mtmd_context_ptr = std::unique_ptr<mtmd_context, mtmd_context_deleter>; +using context_ptr = std::unique_ptr<mtmd_context, mtmd_context_deleter>; -#else +struct mtmd_bitmap_deleter { + void operator()(mtmd_bitmap * val) { mtmd_bitmap_free(val); } +}; +using bitmap_ptr = std::unique_ptr<mtmd_bitmap, mtmd_bitmap_deleter>; -static_assert(false && "C header is not yet supported by this library"); +struct mtmd_input_chunks_deleter { + void operator()(mtmd_input_chunks * val) { mtmd_input_chunks_free(val); } +}; +using input_chunks_ptr = std::unique_ptr<mtmd_input_chunks, mtmd_input_chunks_deleter>; + +struct mtmd_input_chunk_deleter { + void operator()(mtmd_input_chunk * val) { mtmd_input_chunk_free(val); } +}; +using input_chunk_ptr = std::unique_ptr<mtmd_input_chunk, mtmd_input_chunk_deleter>; + +struct bitmap { + bitmap_ptr ptr; + bitmap() : ptr(nullptr) {} + bitmap(mtmd_bitmap * bitmap) : ptr(bitmap) {} + bitmap(bitmap && other) noexcept : ptr(std::move(other.ptr)) {} + bitmap(uint32_t nx, uint32_t ny, const unsigned char * data) { + ptr.reset(mtmd_bitmap_init(nx, ny, data)); + } + ~bitmap() = default; + uint32_t nx() { return mtmd_bitmap_get_nx(ptr.get()); } + uint32_t ny() { return mtmd_bitmap_get_ny(ptr.get()); } + const unsigned char * data() { return mtmd_bitmap_get_data(ptr.get()); } + std::string id() { return mtmd_bitmap_get_id(ptr.get()); } + void set_id(const char * id) { mtmd_bitmap_set_id(ptr.get(), id); } +}; + +struct bitmaps { + std::vector<bitmap> entries; + ~bitmaps() = default; + // return list of pointers to mtmd_bitmap + // example: + // auto bitmaps_c_ptr = bitmaps.c_ptr(); + // int32_t res = mtmd_tokenize(...
bitmaps_c_ptr.data(), bitmaps_c_ptr.size()); + std::vector<const mtmd_bitmap *> c_ptr() { + std::vector<const mtmd_bitmap *> res(entries.size()); + for (size_t i = 0; i < entries.size(); i++) { + res[i] = entries[i].ptr.get(); + } + return res; + } +}; + +struct input_chunks { + input_chunks_ptr ptr; + input_chunks() = default; + input_chunks(mtmd_input_chunks * chunks) : ptr(chunks) {} + ~input_chunks() = default; + size_t size() { return mtmd_input_chunks_size(ptr.get()); } + const mtmd_input_chunk * operator[](size_t idx) { + return mtmd_input_chunks_get(ptr.get(), idx); + } +}; + +} // namespace mtmd #endif
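Putting the pieces together, below is a sketch of a full image-prompt evaluation through the C API and the C++ wrappers above, mirroring the flow in mtmd-cli.cpp. The already-initialized `model`/`lctx`, the file names ("mmproj.gguf", "image.jpg"), and `n_batch = 512` are assumptions, not part of the patch:

```cpp
#include "mtmd.h"
#include <string>
#include <utility>

static int eval_image_prompt(llama_model * model, llama_context * lctx) {
    mtmd_context_params mparams = mtmd_context_params_default();
    mparams.use_gpu = false; // example override of a default
    mtmd::context_ptr ctx(mtmd_init_from_file("mmproj.gguf", model, mparams));
    if (!ctx) return 1;

    // load one image; the helper returns nullptr on failure
    mtmd::bitmaps bitmaps;
    mtmd::bitmap bmp(mtmd_helper_bitmap_init_from_file("image.jpg"));
    if (!bmp.ptr) return 1;
    bitmaps.entries.push_back(std::move(bmp));

    // the marker in the prompt is replaced by image tokens during tokenization
    std::string prompt = std::string("USER: ") + MTMD_DEFAULT_IMAGE_MARKER
                       + "\ndescribe the image\nASSISTANT:";
    mtmd_input_text text { prompt.c_str(), /*add_special*/ true, /*parse_special*/ true };

    mtmd::input_chunks chunks(mtmd_input_chunks_init());
    auto bitmaps_c_ptr = bitmaps.c_ptr();
    if (mtmd_tokenize(ctx.get(), chunks.ptr.get(), &text,
                      bitmaps_c_ptr.data(), bitmaps_c_ptr.size()) != 0) {
        return 1;
    }

    // decode text chunks and encode+decode image chunks, in order
    llama_pos new_n_past = 0;
    return mtmd_helper_eval_chunks(ctx.get(), lctx, chunks.ptr.get(),
                                   /*n_past*/ 0, /*seq_id*/ 0, /*n_batch*/ 512,
                                   /*logits_last*/ true, &new_n_past);
}
```

After a successful call, `new_n_past` holds the next decode position, which is how the CLI above tracks `ctx.n_past` instead of summing chunk sizes manually.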