diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 58e455ae6..b754dd815 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -556,11 +556,8 @@ class TextModel(ModelBase): logger.info(f"gguf: experts used count = {n_experts_used}") if (head_dim := self.hparams.get("head_dim")) is not None: - # Workaround for incorrect AutoConfig value for DeepSeekV3 (is set correctly in DeepSeekV2Model class) - # https://github.com/huggingface/transformers/blob/19224c3642705c5b6988c9f5f4251f83323d05ae/src/transformers/models/deepseek_v3/configuration_deepseek_v3.py#L210 - if self.hparams.get("model_type") != "deepseek_v3": - self.gguf_writer.add_key_length(head_dim) - self.gguf_writer.add_value_length(head_dim) + self.gguf_writer.add_key_length(head_dim) + self.gguf_writer.add_value_length(head_dim) self.gguf_writer.add_file_type(self.ftype) logger.info(f"gguf: file type = {self.ftype}") @@ -1901,9 +1898,7 @@ class LlamaModel(TextModel): hparams = self.hparams self.gguf_writer.add_vocab_size(hparams["vocab_size"]) - if "head_dim" in hparams: - rope_dim = hparams["head_dim"] - else: + if (rope_dim := hparams.get("head_dim")) is None: rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"] self.gguf_writer.add_rope_dimension_count(rope_dim) @@ -1985,7 +1980,8 @@ class LlamaModel(TextModel): if rope_scaling := self.find_hparam(["rope_scaling"], optional=True): if rope_scaling.get("rope_type", '').lower() == "llama3": base = self.hparams.get("rope_theta", 10000.0) - dim = self.hparams.get("head_dim", self.hparams["hidden_size"] // self.hparams["num_attention_heads"]) + if (dim := self.hparams.get("head_dim")) is None: + dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"] freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) factor = rope_scaling.get("factor", 8.0) @@ -2321,9 +2317,7 @@ class DeciModel(TextModel): hparams = self.hparams self.gguf_writer.add_vocab_size(hparams["vocab_size"]) - if "head_dim" in hparams: - rope_dim = hparams["head_dim"] - else: + if (rope_dim := hparams.get("head_dim")) is None: rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"] self.gguf_writer.add_rope_dimension_count(rope_dim) @@ -2363,7 +2357,8 @@ class DeciModel(TextModel): if rope_scaling := self.find_hparam(["rope_scaling"], optional=True): if rope_scaling.get("rope_type", '').lower() == "llama3": base = self.hparams.get("rope_theta", 10000.0) - dim = self.hparams.get("head_dim", self.hparams["hidden_size"] // self.hparams["num_attention_heads"]) + if (dim := self.hparams.get("head_dim")) is None: + dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"] freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) factor = rope_scaling.get("factor", 8.0) @@ -3681,9 +3676,7 @@ class InternLM3Model(TextModel): hparams = self.hparams self.gguf_writer.add_vocab_size(hparams["vocab_size"]) - if "head_dim" in hparams: - rope_dim = hparams["head_dim"] - else: + if (rope_dim := hparams.get("head_dim")) is None: rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"] self.gguf_writer.add_rope_dimension_count(rope_dim) @@ -5098,9 +5091,7 @@ class DeepseekModel(TextModel): def set_gguf_parameters(self): super().set_gguf_parameters() hparams = self.hparams - if "head_dim" in hparams: - rope_dim = hparams["head_dim"] - else: + if (rope_dim := hparams.get("head_dim")) is None: rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"] self.gguf_writer.add_rope_dimension_count(rope_dim) @@ -5990,7 +5981,8 @@ class ExaoneModel(TextModel): if rope_scaling := self.find_hparam(["rope_scaling"], optional=True): if rope_scaling.get("rope_type", '').lower() == "llama3": base = self.hparams.get("rope_theta", 10000.0) - dim = self.hparams.get("head_dim", self.hparams["hidden_size"] // self.hparams["num_attention_heads"]) + if (dim := self.hparams.get("head_dim")) is None: + dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"] freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) factor = rope_scaling.get("factor", 8.0) @@ -6102,7 +6094,8 @@ class BailingMoeModel(TextModel): def set_gguf_parameters(self): super().set_gguf_parameters() hparams = self.hparams - rope_dim = hparams.get("head_dim") or hparams["hidden_size"] // hparams["num_attention_heads"] + if (rope_dim := hparams.get("head_dim")) is None: + rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"] self.gguf_writer.add_rope_dimension_count(rope_dim) rope_scaling = self.hparams.get("rope_scaling") or {} @@ -6134,7 +6127,8 @@ class BailingMoeModel(TextModel): n_head = self.hparams["num_attention_heads"] n_kv_head = self.hparams.get("num_key_value_heads") n_embd = self.hparams["hidden_size"] - head_dim = self.hparams.get("head_dim") or n_embd // n_head + if (head_dim := self.hparams.get("head_dim")) is None: + head_dim = n_embd // n_head output_name = self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 55d3918e0..9e12da330 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -901,12 +901,6 @@ struct ggml_context { struct ggml_object * objects_end; }; -struct ggml_context_container { - bool used; - - struct ggml_context context; -}; - // // data types // diff --git a/src/llama-chat.cpp b/src/llama-chat.cpp index bc4fa05a7..0839cad3e 100644 --- a/src/llama-chat.cpp +++ b/src/llama-chat.cpp @@ -333,7 +333,7 @@ int32_t llm_chat_apply_template( std::string role(message->role); if (role == "system") { // there is no system message for gemma, but we will merge it with user prompt, so nothing is broken - system_prompt = trim(message->content); + system_prompt += trim(message->content); continue; } // in gemma, "assistant" is "model" @@ -355,7 +355,7 @@ int32_t llm_chat_apply_template( std::string role(message->role); if (role == "system") { // there is no system message support, we will merge it with user prompt - system_prompt = message->content; + system_prompt += message->content; continue; } else if (role == "user") { ss << "Human: "; diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 6b80ca6bf..1930841fa 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -205,7 +205,7 @@ struct clip_hparams { float eps = 1e-6; float rope_theta = 0.0; - std::vector image_grid_pinpoints; + std::vector image_res_candidates; // for llava-uhd style models int32_t image_crop_resolution; std::unordered_set vision_feature_layer; int32_t attn_window_size = 0; @@ -2149,8 +2149,7 @@ struct clip_model_loader { if (is_vision) { get_u32(KEY_IMAGE_SIZE, hparams.image_size); get_u32(KEY_PATCH_SIZE, hparams.patch_size); - get_u32(KEY_IMAGE_CROP_RESOLUTION, hparams.image_crop_resolution, false); - get_arr_int(KEY_IMAGE_GRID_PINPOINTS, hparams.image_grid_pinpoints, false); + get_u32(KEY_IMAGE_CROP_RESOLUTION, hparams.image_crop_resolution, false); get_i32(KEY_MINICPMV_VERSION, hparams.minicpmv_version, false); // legacy } else if (is_audio) { @@ -2160,6 +2159,20 @@ struct clip_model_loader { GGML_ASSERT(false && "unknown modality"); } + // for pinpoints, we need to convert it into a list of resolution candidates + { + std::vector pinpoints; + get_arr_int(KEY_IMAGE_GRID_PINPOINTS, pinpoints, false); + if (!pinpoints.empty()) { + for (size_t i = 0; i < pinpoints.size(); i += 2) { + hparams.image_res_candidates.push_back({ + pinpoints[i], + pinpoints[i+1], + }); + } + } + } + // default warmup value hparams.warmup_image_size = hparams.image_size; @@ -2282,16 +2295,7 @@ struct clip_model_loader { { hparams.rope_theta = 10000.0f; get_u32(KEY_PROJ_SCALE_FACTOR, hparams.proj_scale_factor); - - // borrowed from llava-1.6 - const int isize = hparams.image_size; - hparams.image_grid_pinpoints = { - isize, isize*2, // 336, 672 - isize*2, isize, // 672, 336 - isize*2, isize*2, // 672, 672 - isize*3, isize, // 1008, 336 - isize, isize*3, // 336, 1008 - }; + set_llava_uhd_res_candidates(model, 3); } break; case PROJECTOR_TYPE_ULTRAVOX: case PROJECTOR_TYPE_QWEN2A: @@ -2730,6 +2734,21 @@ struct clip_model_loader { output[i] = values[i]; } } + + void set_llava_uhd_res_candidates(clip_model & model, const int max_patches_per_side) { + auto & hparams = model.hparams; + for (int x = 1; x <= max_patches_per_side; x++) { + for (int y = 1; y <= max_patches_per_side; y++) { + if (x == 1 && y == 1) { + continue; // skip the first point + } + hparams.image_res_candidates.push_back(clip_image_size{ + x*hparams.image_size, + y*hparams.image_size, + }); + } + } + } }; struct clip_init_result clip_init(const char * fname, struct clip_context_params ctx_params) { @@ -3195,36 +3214,41 @@ struct llava_uhd { bool padding_refined = false; // if true, refine image will be padded to the grid size (e.g. llava-1.6) }; - static int get_max_slices(struct clip_ctx * ctx) { - if (clip_is_minicpmv(ctx)) { - return 9; - } - return 0; - } - static slice_instructions get_slice_instructions(struct clip_ctx * ctx, const clip_image_size & original_size) { slice_instructions res; const int patch_size = clip_get_patch_size(ctx); const int slice_size = clip_get_image_size(ctx); - const int max_slice_nums = get_max_slices(ctx); const int original_width = original_size.width; const int original_height = original_size.height; - const float log_ratio = log((float)original_width / original_height); - const float ratio = (float)original_width * original_height / (slice_size * slice_size); - const int multiple = fmin(ceil(ratio), max_slice_nums); - const bool has_slices = (multiple > 1); - const bool has_pinpoints = !ctx->model.hparams.image_grid_pinpoints.empty(); + + const bool has_slices = original_size.width > slice_size || original_size.height > slice_size; + const bool has_pinpoints = !ctx->model.hparams.image_res_candidates.empty(); + + if (!has_slices) { + // skip slicing logic + res.overview_size = clip_image_size{slice_size, slice_size}; + res.refined_size = clip_image_size{0, 0}; + res.grid_size = clip_image_size{0, 0}; + + return res; + } if (has_pinpoints) { // has pinpoints, use them to calculate the grid size (e.g. llava-1.6) auto refine_size = llava_uhd::select_best_resolution( - ctx->model.hparams.image_grid_pinpoints, - original_size); + original_size, + ctx->model.hparams.image_res_candidates); res.overview_size = clip_image_size{slice_size, slice_size}; res.refined_size = refine_size; res.grid_size = clip_image_size{0, 0}; res.padding_refined = true; + LOG_DBG("%s: using pinpoints for slicing\n", __func__); + LOG_DBG("%s: original size: %d x %d, overview size: %d x %d, refined size: %d x %d\n", + __func__, original_width, original_height, + res.overview_size.width, res.overview_size.height, + res.refined_size.width, res.refined_size.height); + for (int y = 0; y < refine_size.height; y += slice_size) { for (int x = 0; x < refine_size.width; x += slice_size) { slice_coordinates slice; @@ -3233,13 +3257,16 @@ struct llava_uhd { slice.size.width = std::min(slice_size, refine_size.width - x); slice.size.height = std::min(slice_size, refine_size.height - y); res.slices.push_back(slice); - if (x == 0) { - res.grid_size.width++; - } + LOG_DBG("%s: slice %d: x=%d, y=%d, size=%dx%d\n", + __func__, (int)res.slices.size() - 1, + slice.x, slice.y, slice.size.width, slice.size.height); } - res.grid_size.height++; } + res.grid_size.height = refine_size.height / slice_size; + res.grid_size.width = refine_size.width / slice_size; + LOG_DBG("%s: grid size: %d x %d\n", __func__, res.grid_size.width, res.grid_size.height); + return res; } @@ -3248,17 +3275,23 @@ struct llava_uhd { auto best_size = get_best_resize(original_size, slice_size, patch_size, !has_slices); res.overview_size = best_size; - if (!has_slices) { - // skip slicing logic - res.refined_size = clip_image_size{0, 0}; - res.grid_size = clip_image_size{0, 0}; + { + const int max_slice_nums = 9; // TODO: this is only used by minicpmv, maybe remove it + const float log_ratio = log((float)original_width / original_height); + const float ratio = (float)original_width * original_height / (slice_size * slice_size); + const int multiple = fmin(ceil(ratio), max_slice_nums); - } else { auto best_grid = get_best_grid(max_slice_nums, multiple, log_ratio); auto refine_size = get_refine_size(original_size, best_grid, slice_size, patch_size, true); res.grid_size = best_grid; res.refined_size = refine_size; + LOG_DBG("%s: original size: %d x %d, overview size: %d x %d, refined size: %d x %d, grid size: %d x %d\n", + __func__, original_width, original_height, + res.overview_size.width, res.overview_size.height, + res.refined_size.width, res.refined_size.height, + res.grid_size.width, res.grid_size.height); + int width = refine_size.width; int height = refine_size.height; int grid_x = int(width / best_grid.width); @@ -3275,7 +3308,9 @@ struct llava_uhd { slice.size.width = grid_x; slice.size.height = grid_y; res.slices.push_back(slice); - // LOG_INF("slice %d: %d %d %d %d\n", ic, patches_i, patches_j, grid_x, grid_y); + LOG_DBG("%s: slice %d: x=%d, y=%d, size=%dx%d\n", + __func__, (int)res.slices.size() - 1, + slice.x, slice.y, slice.size.width, slice.size.height); } } } @@ -3333,48 +3368,55 @@ private: return res; } + static clip_image_size resize_maintain_aspect_ratio(const clip_image_size & orig, const clip_image_size & target_max) { + float scale_width = static_cast(target_max.width) / orig.width; + float scale_height = static_cast(target_max.height) / orig.height; + float scale = std::min(scale_width, scale_height); + return clip_image_size{ + static_cast(orig.width * scale), + static_cast(orig.height * scale), + }; + } + /** * Selects the best resolution from a list of possible resolutions based on the original size. * + * For example, when given a list of resolutions: + * - 100x100 + * - 200x100 + * - 100x200 + * - 200x200 + * + * And an input image of size 111x200, then 100x200 is the best fit (least wasted resolution). + * * @param original_size The original size of the image * @param possible_resolutions A list of possible resolutions * @return The best fit resolution */ static clip_image_size select_best_resolution(const clip_image_size & original_size, const std::vector & possible_resolutions) { - int original_width = original_size.width; - int original_height = original_size.height; clip_image_size best_fit; + int min_wasted_area = std::numeric_limits::max(); int max_effective_resolution = 0; - int min_wasted_resolution = std::numeric_limits::max(); - for (const auto & resolution : possible_resolutions) { - int width = resolution.width; - int height = resolution.height; - float scale = std::min(static_cast(width) / original_width, static_cast(height) / original_height); - int downscaled_width = static_cast(original_width * scale); - int downscaled_height = static_cast(original_height * scale); - int effective_resolution = std::min(downscaled_width * downscaled_height, original_width * original_height); - int wasted_resolution = (width * height) - effective_resolution; - // LOG_INF("resolution: %d %d, scale: %f, downscaled: %d %d, effective: %d, wasted: %d\n", width, height, scale, downscaled_width, downscaled_height, effective_resolution, wasted_resolution); - if (effective_resolution > max_effective_resolution || (effective_resolution == max_effective_resolution && wasted_resolution < min_wasted_resolution)) { + for (const clip_image_size & candidate : possible_resolutions) { + auto target_size = resize_maintain_aspect_ratio(original_size, candidate); + int effective_resolution = std::min( + target_size.width * target_size.height, + original_size.width * original_size.height); + int wasted_area = (candidate.width * candidate.height) - effective_resolution; + + if (effective_resolution > max_effective_resolution || (effective_resolution == max_effective_resolution && wasted_area < min_wasted_area)) { max_effective_resolution = effective_resolution; - min_wasted_resolution = wasted_resolution; - best_fit = resolution; + min_wasted_area = wasted_area; + best_fit = candidate; } + + LOG_DBG("%s: candidate: %d x %d, target: %d x %d, wasted: %d, effective: %d\n", __func__, candidate.width, candidate.height, target_size.width, target_size.height, wasted_area, effective_resolution); } return best_fit; } - // used by llava 1.6 with custom list of pinpoints - static clip_image_size select_best_resolution(const std::vector & pinpoints, const clip_image_size & original_size) { - std::vector possible_resolutions; // TODO @ngxson : construct this inside hparams, not here - for (size_t i = 0; i < pinpoints.size(); i += 2) { - possible_resolutions.push_back(clip_image_size{pinpoints[i], pinpoints[i+1]}); - } - return select_best_resolution(original_size, possible_resolutions); - } - static int ensure_divide(int length, int patch_size) { return std::max(static_cast(std::round(static_cast(length) / patch_size) * patch_size), patch_size); } @@ -3498,7 +3540,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str return true; } else if (ctx->proj_type() == PROJECTOR_TYPE_LLAMA4) { - GGML_ASSERT(!params.image_grid_pinpoints.empty()); + GGML_ASSERT(!params.image_res_candidates.empty()); auto const inst = llava_uhd::get_slice_instructions(ctx, original_size); std::vector imgs = llava_uhd::slice_image(img, inst); @@ -3538,7 +3580,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str res_imgs->entries.push_back(std::move(res)); return true; - } else if (!params.image_grid_pinpoints.empty()) { + } else if (!params.image_res_candidates.empty()) { // "spatial_unpad" with "anyres" processing for llava-1.6 auto const inst = llava_uhd::get_slice_instructions(ctx, original_size); std::vector imgs = llava_uhd::slice_image(img, inst); @@ -3598,17 +3640,6 @@ const char * clip_patch_merge_type(const struct clip_ctx * ctx) { return ctx->model.hparams.mm_patch_merge_type == PATCH_MERGE_SPATIAL_UNPAD ? "spatial_unpad" : "flat"; } -const int32_t * clip_image_grid(const struct clip_ctx * ctx) { - if (ctx->model.hparams.image_grid_pinpoints.size()) { - return &ctx->model.hparams.image_grid_pinpoints.front(); - } - return nullptr; -} - -size_t get_clip_image_grid_size(const struct clip_ctx * ctx) { - return ctx->model.hparams.image_grid_pinpoints.size(); -} - int clip_n_output_tokens_x(const struct clip_ctx * ctx, struct clip_image_f32 * img) { const auto & params = ctx->model.hparams; const int n_total = clip_n_output_tokens(ctx, img); diff --git a/tools/mtmd/clip.h b/tools/mtmd/clip.h index 5956fcea4..22ca53bba 100644 --- a/tools/mtmd/clip.h +++ b/tools/mtmd/clip.h @@ -46,9 +46,6 @@ int32_t clip_get_hidden_size(const struct clip_ctx * ctx); // TODO: should be enum, not string const char * clip_patch_merge_type(const struct clip_ctx * ctx); -const int32_t * clip_image_grid(const struct clip_ctx * ctx); -size_t get_clip_image_grid_size(const struct clip_ctx * ctx); - int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * img); // for M-RoPE, this will be the number of token positions in X and Y directions diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp index 8573f1143..e38297383 100644 --- a/tools/mtmd/mtmd.cpp +++ b/tools/mtmd/mtmd.cpp @@ -501,7 +501,10 @@ struct mtmd_tokenizer { || ctx->slice_tmpl == MTMD_SLICE_TMPL_MINICPMV_2_6 || ctx->slice_tmpl == MTMD_SLICE_TMPL_LLAMA4 ) { + const int n_col = batch_f32.grid_x; + const int n_row = batch_f32.grid_y; // split batch into chunks of single images + // NOTE: batch_f32 will be invalidated after this call auto chunks = split_batch_to_chunk(std::move(batch_f32), bitmap->id); GGML_ASSERT(chunks.size() > 0); @@ -521,8 +524,7 @@ struct mtmd_tokenizer { // add slices (or tiles) if (!chunks.empty()) { - const int n_col = batch_f32.grid_x; - const int n_row = batch_f32.grid_y; + GGML_ASSERT((int)chunks.size() == n_row * n_col); if (ctx->tok_slices_start != LLAMA_TOKEN_NULL) { add_text({ctx->tok_slices_start}); }