diff --git a/examples/llava/clip-impl.h b/examples/llava/clip-impl.h index 685d6e7e0..c93d7877c 100644 --- a/examples/llava/clip-impl.h +++ b/examples/llava/clip-impl.h @@ -22,6 +22,8 @@ #define KEY_HAS_QWEN2VL_MERGER "clip.has_qwen2vl_merger" #define KEY_USE_GELU "clip.use_gelu" #define KEY_USE_SILU "clip.use_silu" +#define KEY_USE_GLU_MLP "clip.use_glu_mlp" +#define KEY_USE_RMS_NORM "clip.use_rms_norm" #define KEY_N_EMBD "clip.%s.embedding_length" #define KEY_N_FF "clip.%s.feed_forward_length" #define KEY_N_BLOCK "clip.%s.block_count" @@ -40,6 +42,8 @@ #define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type" #define KEY_IMAGE_GRID_PINPOINTS "clip.vision.image_grid_pinpoints" #define KEY_IMAGE_CROP_RESOLUTION "clip.vision.image_crop_resolution" +#define KEY_FULLATTN_BLK_IDX "clip.vision.fullatt_block_indexes" +#define KEY_ATTN_WINDOW_SIZE "clip.vision.window_size" // @@ -58,6 +62,7 @@ #define TN_ATTN_OUTPUT "%s.blk.%d.attn_out.%s" #define TN_FFN_DOWN "%s.blk.%d.ffn_down.%s" #define TN_FFN_UP "%s.blk.%d.ffn_up.%s" +#define TN_FFN_GATE "%s.blk.%d.ffn_gate.%s" #define TN_LN_1 "%s.blk.%d.ln1.%s" #define TN_LN_2 "%s.blk.%d.ln2.%s" #define TN_LN_PRE "%s.pre_ln.%s" diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp index 1145c816c..fe87aa2b9 100644 --- a/examples/llava/clip.cpp +++ b/examples/llava/clip.cpp @@ -183,6 +183,7 @@ struct clip_hparams { std::vector image_grid_pinpoints; int32_t image_crop_resolution; std::unordered_set vision_feature_layer; + std::vector full_attn_layers; }; struct clip_layer { @@ -208,6 +209,9 @@ struct clip_layer { struct ggml_tensor * ff_o_w = nullptr; struct ggml_tensor * ff_o_b = nullptr; + struct ggml_tensor * ff_g_w = NULL; + struct ggml_tensor * ff_g_b = NULL; + // layernorm 2 struct ggml_tensor * ln_2_w = nullptr; struct ggml_tensor * ln_2_b = nullptr; @@ -331,6 +335,8 @@ struct clip_ctx { float image_std[3]; bool use_gelu = false; bool use_silu = false; + bool use_glu_mlp = false; + bool use_rms_norm = false; int32_t ftype = 1; struct gguf_context * ctx_gguf = nullptr; @@ -576,6 +582,7 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im const int n_head = hparams.n_head; const int d_head = hidden_size / n_head; const float eps = hparams.eps; + const bool use_window_attn = hparams.full_attn_layers.size() > 0; int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4}; const int batch_size = imgs->size; @@ -626,8 +633,10 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im // inp = ggml_add(ctx0, inp, ggml_repeat(ctx0, model.patch_bias, inp)); inp = ggml_add(ctx0, inp, model.patch_bias); } - struct ggml_tensor * embeddings = inp; - struct ggml_tensor * pos_embed = nullptr; + struct ggml_tensor * embeddings = inp; + struct ggml_tensor * pos_embed = nullptr; + struct ggml_tensor * window_mask = nullptr; + struct ggml_tensor * window_idx = nullptr; if (ctx->has_llava_projector) { // concat class_embeddings and patch_embeddings @@ -679,6 +688,28 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im const auto & vision_feature_layer = hparams.vision_feature_layer; // loop over layers + + if (use_window_attn) { + window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions / 4); + ggml_set_name(window_idx, "window_idx"); + ggml_set_input(window_idx); + + // mask for window attention + window_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, num_positions, num_positions); + ggml_set_name(window_mask, "window_mask"); + ggml_set_input(window_mask); + + // embeddings shape: [hidden_size, patches_w * patches_h, batch_size] + GGML_ASSERT(batch_size == 1); + embeddings = ggml_reshape_2d(ctx0, embeddings, hidden_size * 4, patches_w * patches_h * batch_size / 4); + embeddings = ggml_get_rows(ctx0, embeddings, window_idx); + embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size, patches_w * patches_h, batch_size); + + positions = ggml_reshape_2d(ctx0, positions, 16, num_position_ids / 4 / 4); + positions = ggml_get_rows(ctx0, positions, window_idx); + positions = ggml_reshape_1d(ctx0, positions, num_position_ids); + } + for (int il = 0; il < ctx->max_feature_layer; il++) { struct ggml_tensor * cur = embeddings; // embeddings = residual, cur = hidden_states @@ -691,9 +722,12 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im //const size_t nb_q_w = model.layers[il].q_w->nb[0]; // layernorm1 - { + if (ctx->use_rms_norm) { + cur = ggml_rms_norm(ctx0, cur, eps); + cur = ggml_mul(ctx0, cur, model.layers[il].ln_1_w); + } + else { cur = ggml_norm(ctx0, cur, eps); - cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_1_w), model.layers[il].ln_1_b); } @@ -733,7 +767,14 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im V = ggml_reshape_3d(ctx0, V, num_positions, d_head, n_head * batch_size); struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); - KQ = ggml_soft_max_ext(ctx0, KQ, nullptr, 1.0f / sqrtf((float)d_head), 0.0f); + const bool inlist = std::find(hparams.full_attn_layers.begin(), hparams.full_attn_layers.end(), il) != hparams.full_attn_layers.end(); + const bool full_attn = use_window_attn ? inlist : true; + if (full_attn) { + KQ = ggml_soft_max_ext(ctx0, KQ, nullptr, 1.0f / sqrtf((float)d_head), 0.0f); + } else { + KQ = ggml_soft_max_ext(ctx0, KQ, window_mask, 1.0f, 0.0f); + } + struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ); KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_positions, n_head, batch_size); KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3); @@ -750,25 +791,50 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im embeddings = cur; // embeddings = residual, cur = hidden_states // layernorm2 - { + if (ctx->use_rms_norm) { + cur = ggml_rms_norm(ctx0, cur, eps); + cur = ggml_mul(ctx0, cur, model.layers[il].ln_2_w); + } else { cur = ggml_norm(ctx0, cur, eps); - cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_2_w), model.layers[il].ln_2_b); } - cur = ggml_mul_mat(ctx0, model.layers[il].ff_i_w, cur); - cur = ggml_add(ctx0, cur, model.layers[il].ff_i_b); + // mlp + if (ctx->use_glu_mlp) { + // ffn_up + auto cur_up = ggml_mul_mat(ctx0, model.layers[il].ff_o_w, cur); + cur_up = ggml_add(ctx0, cur_up, model.layers[il].ff_o_b); - if (ctx->use_gelu) { - cur = ggml_gelu_inplace(ctx0, cur); - } else if (ctx->use_silu) { - cur = ggml_silu_inplace(ctx0, cur); - } else { - cur = ggml_gelu_quick_inplace(ctx0, cur); + auto cur_gate = ggml_mul_mat(ctx0, model.layers[il].ff_g_w, cur); + cur_gate = ggml_add(ctx0, cur_gate, model.layers[il].ff_g_b); + if (ctx->use_gelu) { + cur_gate = ggml_gelu_inplace(ctx0, cur_gate); + } else if (ctx->use_silu) { + cur_gate = ggml_silu_inplace(ctx0, cur_gate); + } else { + cur_gate = ggml_gelu_quick_inplace(ctx0, cur_gate); + } + cur = ggml_mul(ctx0, cur_gate, cur_up); + + // ffn_down + cur = ggml_mul_mat(ctx0, model.layers[il].ff_i_w, cur); + cur = ggml_add(ctx0, cur, model.layers[il].ff_i_b); } + else { + cur = ggml_mul_mat(ctx0, model.layers[il].ff_i_w, cur); + cur = ggml_add(ctx0, cur, model.layers[il].ff_i_b); - cur = ggml_mul_mat(ctx0, model.layers[il].ff_o_w, cur); - cur = ggml_add(ctx0, cur, model.layers[il].ff_o_b); + if (ctx->use_gelu) { + cur = ggml_gelu_inplace(ctx0, cur); + } else if (ctx->use_silu) { + cur = ggml_silu_inplace(ctx0, cur); + } else { + cur = ggml_gelu_quick_inplace(ctx0, cur); + } + + cur = ggml_mul_mat(ctx0, model.layers[il].ff_o_w, cur); + cur = ggml_add(ctx0, cur, model.layers[il].ff_o_b); + } // residual 2 cur = ggml_add(ctx0, embeddings, cur); @@ -778,10 +844,17 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im // post-layernorm if (model.post_ln_w) { - embeddings = ggml_norm(ctx0, embeddings, eps); - ggml_set_name(embeddings, "post_ln"); + if (ctx->use_rms_norm) { + embeddings = ggml_rms_norm(ctx0, embeddings, eps); + ggml_set_name(embeddings, "post_ln"); - embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.post_ln_w), model.post_ln_b); + embeddings = ggml_mul(ctx0, embeddings, model.post_ln_w); + } else { + embeddings = ggml_norm(ctx0, embeddings, eps); + ggml_set_name(embeddings, "post_ln"); + + embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.post_ln_w), model.post_ln_b); + } } // final layer is a vision feature layer @@ -1095,6 +1168,18 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im embeddings = ggml_add(ctx0, embeddings, model.mm_1_b); } + if (use_window_attn) { + struct ggml_tensor * inv_window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions / 4); + ggml_set_name(inv_window_idx, "inv_window_idx"); + ggml_set_input(inv_window_idx); + + // embeddings shape: [hidden_size, patches_w * patches_h, batch_size] + GGML_ASSERT(batch_size == 1); + embeddings = ggml_reshape_2d(ctx0, embeddings, hparams.projection_dim, patches_w * patches_h / 4); + embeddings = ggml_get_rows(ctx0, embeddings, inv_window_idx); + embeddings = ggml_reshape_3d(ctx0, embeddings, hparams.projection_dim, patches_w * patches_h / 4, batch_size); + } + // build the graph ggml_build_forward_expand(gf, embeddings); @@ -1203,6 +1288,8 @@ struct clip_model_loader { get_bool(KEY_USE_GELU, ctx_clip.use_gelu, false); get_bool(KEY_USE_SILU, ctx_clip.use_silu, false); + get_bool(KEY_USE_GLU_MLP, ctx_clip.use_glu_mlp, false); + get_bool(KEY_USE_RMS_NORM, ctx_clip.use_rms_norm, false); auto & hparams = ctx_clip.vision_model.hparams; get_u32(string_format(KEY_N_EMBD, "vision"), hparams.hidden_size); @@ -1215,6 +1302,7 @@ struct clip_model_loader { get_u32(KEY_PATCH_SIZE, hparams.patch_size); get_u32(KEY_IMAGE_CROP_RESOLUTION, hparams.image_crop_resolution, false); get_arr_int(KEY_IMAGE_GRID_PINPOINTS, hparams.image_grid_pinpoints, false); + get_arr_int(KEY_FULLATTN_BLK_IDX, hparams.full_attn_layers, false); { std::string mm_patch_merge_type; @@ -1330,14 +1418,16 @@ struct clip_model_loader { layer.ln_2_w = get_tensor(string_format(TN_LN_2, "v", il, "weight"), false); layer.ff_i_w = get_tensor(string_format(TN_FFN_DOWN, "v", il, "weight")); layer.ff_o_w = get_tensor(string_format(TN_FFN_UP, "v", il, "weight")); + layer.ff_g_w = get_tensor(string_format(TN_FFN_GATE, "v", il, "weight"), ctx_clip.use_glu_mlp); layer.k_b = get_tensor(string_format(TN_ATTN_K, "v", il, "bias"), false); layer.q_b = get_tensor(string_format(TN_ATTN_Q, "v", il, "bias"), false); layer.v_b = get_tensor(string_format(TN_ATTN_V, "v", il, "bias"), false); layer.o_b = get_tensor(string_format(TN_ATTN_OUTPUT, "v", il, "bias"), false); - layer.ln_1_b = get_tensor(string_format(TN_LN_1, "v", il, "bias"), false); - layer.ln_2_b = get_tensor(string_format(TN_LN_2, "v", il, "bias"), false); + layer.ln_1_b = get_tensor(string_format(TN_LN_1, "v", il, "bias"), !ctx_clip.use_rms_norm); + layer.ln_2_b = get_tensor(string_format(TN_LN_2, "v", il, "bias"), !ctx_clip.use_rms_norm); layer.ff_i_b = get_tensor(string_format(TN_FFN_DOWN, "v", il, "bias"), false); layer.ff_o_b = get_tensor(string_format(TN_FFN_UP, "v", il, "bias"), false); + layer.ff_g_b = get_tensor(string_format(TN_FFN_GATE, "v", il, "bias"), ctx_clip.use_glu_mlp); } switch (ctx_clip.proj_type) { diff --git a/examples/llava/qwen2_vl_surgery.py b/examples/llava/qwen2_vl_surgery.py index c87606b4f..8f7a94e5c 100644 --- a/examples/llava/qwen2_vl_surgery.py +++ b/examples/llava/qwen2_vl_surgery.py @@ -5,10 +5,12 @@ import torch import numpy as np from gguf import * from transformers import ( - Qwen2VLForConditionalGeneration, - Qwen2VLProcessor, AutoProcessor, - Qwen2VLConfig + Qwen2VLForConditionalGeneration, + Qwen2_5_VLForConditionalGeneration, + Qwen2VLProcessor, + Qwen2VLConfig, + Qwen2_5_VLConfig, ) @@ -18,62 +20,80 @@ VISION = "clip.vision" def k(raw_key: str, arch: str) -> str: return raw_key.format(arch=arch) +class VL2: -def to_gguf_name(name: str) -> str: - og = name - name = name.replace("text_model", "t").replace("vision_model", "v") - name = name.replace("blocks", "blk").replace("embeddings.", "") - name = name.replace("attn.", "attn_") - name = name.replace("mlp.fc1", "ffn_down").replace("mlp.fc2", "ffn_up").replace("proj.", "out.") - # name = name.replace("layrnorm", "ln").replace("layer_norm", "ln").replace("layernorm", "ln") - name = name.replace("norm1", "ln1").replace("norm2", "ln2") - name = name.replace("merger.mlp", 'mm') - print(f"[to_gguf_name] {og} --> {name}") - return name + @staticmethod + def to_gguf_name(name: str) -> str: + og = name + name = name.replace("text_model", "t").replace("vision_model", "v") + name = name.replace("blocks", "blk").replace("embeddings.", "") + name = name.replace("attn.", "attn_") + name = name.replace("mlp.fc1", "ffn_down").replace("mlp.fc2", "ffn_up").replace("proj.", "out.") + # name = name.replace("layrnorm", "ln").replace("layer_norm", "ln").replace("layernorm", "ln") + name = name.replace("norm1", "ln1").replace("norm2", "ln2") + name = name.replace("merger.mlp", 'mm') + print(f"[to_gguf_name] {og} --> {name}") + return name - -def find_vision_tensors(qwen2vl, dtype) -> Dict[str, np.ndarray]: - vision_model = qwen2vl.visual - tensor_map = {} - for name, ten in vision_model.state_dict().items(): - ten = ten.numpy() - if 'qkv' in name: - if ten.ndim == 2: # weight - c3, _ = ten.shape - else: # bias - c3 = ten.shape[0] - assert c3 % 3 == 0 - c = c3 // 3 - wq = ten[:c] - wk = ten[c: c * 2] - wv = ten[c * 2:] - tensor_map[to_gguf_name(f"vision_model.{name}").replace("qkv", "q")] = wq - tensor_map[to_gguf_name(f"vision_model.{name}").replace("qkv", "k")] = wk - tensor_map[to_gguf_name(f"vision_model.{name}").replace("qkv", "v")] = wv - elif 'merger' in name: - if name.endswith("ln_q.weight"): - tensor_map['v.post_ln.weight'] = ten - elif name.endswith("ln_q.bias"): - tensor_map['v.post_ln.bias'] = ten + @classmethod + def find_vision_tensors(cls, qwen2vl, dtype) -> Dict[str, np.ndarray]: + vision_model = qwen2vl.visual + tensor_map = {} + for name, ten in vision_model.state_dict().items(): + ten = ten.numpy() + if 'qkv' in name: + if ten.ndim == 2: # weight + c3, _ = ten.shape + else: # bias + c3 = ten.shape[0] + assert c3 % 3 == 0 + c = c3 // 3 + wq = ten[:c] + wk = ten[c: c * 2] + wv = ten[c * 2:] + tensor_map[cls.to_gguf_name(f"vision_model.{name}").replace("qkv", "q")] = wq + tensor_map[cls.to_gguf_name(f"vision_model.{name}").replace("qkv", "k")] = wk + tensor_map[cls.to_gguf_name(f"vision_model.{name}").replace("qkv", "v")] = wv + elif 'merger' in name: + if name.endswith("ln_q.weight"): + tensor_map['v.post_ln.weight'] = ten + elif name.endswith("ln_q.bias"): + tensor_map['v.post_ln.bias'] = ten + else: + # "merger.mlp.%d.weight/bias" --> "mm.%d.weight/bias" + tensor_map[cls.to_gguf_name(name)] = ten + elif 'patch_embed.proj.weight' in name: + # NOTE: split Conv3D into Conv2Ds + c1, c2, kt, kh, kw = ten.shape + assert kt == 2, "Current implmentation only support temporal_patch_size of 2" + tensor_map["v.patch_embd.weight"] = ten[:, :, 0, ...] + tensor_map["v.patch_embd.weight.1"] = ten[:, :, 1, ...] else: - # "merger.mlp.%d.weight/bias" --> "mm.%d.weight/bias" - tensor_map[to_gguf_name(name)] = ten - elif 'patch_embed.proj.weight' in name: - # NOTE: split Conv3D into Conv2Ds - c1, c2, kt, kh, kw = ten.shape - assert kt == 2, "Current implmentation only support temporal_patch_size of 2" - tensor_map["v.patch_embd.weight"] = ten[:, :, 0, ...] - tensor_map["v.patch_embd.weight.1"] = ten[:, :, 1, ...] - else: - tensor_map[to_gguf_name(f"vision_model.{name}")] = ten + tensor_map[cls.to_gguf_name(f"vision_model.{name}")] = ten - for new_name, ten in tensor_map.items(): - if ten.ndim <= 1 or new_name.endswith("_norm.weight"): - tensor_map[new_name] = ten.astype(np.float32) - else: - tensor_map[new_name] = ten.astype(dtype) - tensor_map["v.position_embd.weight"] = np.zeros([10, 10], dtype=np.float32) # dummy tensor, just here as a placeholder - return tensor_map + for new_name, ten in tensor_map.items(): + if ten.ndim <= 1 or new_name.endswith("_norm.weight"): + tensor_map[new_name] = ten.astype(np.float32) + else: + tensor_map[new_name] = ten.astype(dtype) + tensor_map["v.position_embd.weight"] = np.zeros([10, 10], dtype=np.float32) # dummy tensor, just here as a placeholder + return tensor_map + + +class VL25(VL2): + + @staticmethod + def to_gguf_name(name: str) -> str: + og = name + name = name.replace("text_model", "t").replace("vision_model", "v") + name = name.replace("blocks", "blk").replace("embeddings.", "") + name = name.replace("attn.", "attn_") + name = name.replace("mlp.down_proj", "ffn_down").replace("mlp.up_proj", "ffn_up") + name = name.replace("mlp.gate_proj", "ffn_gate").replace("proj.", "out.") + name = name.replace("norm1", "ln1").replace("norm2", "ln2") + name = name.replace("merger.mlp", 'mm') + print(f"[vl25][to_gguf_name] {og} --> {name}") + return name def main(args): @@ -92,11 +112,18 @@ def main(args): model_path = "" model_name = args.model_name print("model_name: ", model_name) - qwen2vl = Qwen2VLForConditionalGeneration.from_pretrained( - model_name, torch_dtype=dtype, device_map="cpu" - ) - cfg: Qwen2VLConfig = qwen2vl.config # type: ignore[reportAssignmentType] - vcfg = cfg.vision_config + if args.model_type == "qwen2vl": + qwen2vl = Qwen2VLForConditionalGeneration.from_pretrained( + model_name, torch_dtype=dtype, device_map="cpu" + ) + cfg: Qwen2VLConfig = qwen2vl.config # type: ignore[reportAssignmentType] + vcfg = cfg.vision_config + else: + qwen2vl = Qwen2_5_VLForConditionalGeneration.from_pretrained( + model_name, torch_dtype=dtype, device_map="cpu" + ) + cfg: Qwen2_5_VLConfig = qwen2vl.config # type: ignore[reportAssignmentType] + vcfg = cfg.vision_config if os.path.isdir(model_name): local_model = True @@ -125,14 +152,26 @@ def main(args): else: raise ValueError() - tensor_map = find_vision_tensors(qwen2vl, np_dtype) + if args.model_type == "qwen2.5vl": + fout.add_bool("clip.use_glu_mlp", True) # gate linear unit MLP layer in vision model + fout.add_bool("clip.use_rms_norm", True) + fout.add_array("clip.vision.fullatt_block_indexes", vcfg.fullatt_block_indexes) + fout.add_uint32("clip.vision.window_size", vcfg.window_size) + fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), vcfg.hidden_size) + fout.add_uint32("clip.vision.projection_dim", vcfg.out_hidden_size) + else: + fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), vcfg.embed_dim) + fout.add_uint32("clip.vision.projection_dim", vcfg.hidden_size) + + if args.model_type == "qwen2.5vl": + tensor_map = VL25.find_vision_tensors(qwen2vl, np_dtype) + else: + tensor_map = VL2.find_vision_tensors(qwen2vl, np_dtype) for name, data in tensor_map.items(): fout.add_tensor(name, data) fout.add_uint32("clip.vision.patch_size", vcfg.patch_size) fout.add_uint32("clip.vision.image_size", 14 * 40) # some reasonable size that is divable by (14*2) - fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), vcfg.embed_dim) - fout.add_uint32("clip.vision.projection_dim", vcfg.hidden_size) fout.add_uint32(k(KEY_ATTENTION_HEAD_COUNT, VISION), vcfg.num_heads) fout.add_float32(k(KEY_ATTENTION_LAYERNORM_EPS, VISION), 1e-6) fout.add_uint32(k(KEY_BLOCK_COUNT, VISION), vcfg.depth) @@ -160,6 +199,7 @@ def main(args): if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("model_name", nargs='?', default="Qwen/Qwen2-VL-2B-Instruct") + parser.add_argument("--model_type", nargs='?', choices=['qwen2vl', 'qwen2.5vl'], default="qwen2vl") parser.add_argument("--data_type", nargs='?', choices=['fp32', 'fp16'], default="fp32") args = parser.parse_args() main(args)