cleaning up

HimariO 2025-04-01 21:07:36 +08:00
parent c891300c1e
commit fdae70a832
3 changed files with 8 additions and 60 deletions

View file

@@ -615,10 +615,6 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
auto inp_1 = ggml_conv_2d(ctx0, model.patch_embeddings_1, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
inp = ggml_add(ctx0, inp, inp_1);
// ggml_build_forward_expand(gf, inp);
// ggml_free(ctx0);
// return gf;
inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 2, 0, 3)); // [w, h, c, b] -> [c, w, h, b]
inp = ggml_reshape_4d(
ctx0, inp,
@@ -630,10 +626,6 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
inp = ggml_reshape_3d(
ctx0, inp,
hidden_size, patches_w * patches_h, batch_size);
// ggml_build_forward_expand(gf, inp);
// ggml_free(ctx0);
// return gf;
}
else {
inp = ggml_reshape_3d(ctx0, inp, num_patches, hidden_size, batch_size);
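Note: the two hunks above reduce the conv2d patch embedding to [hidden_size, n_patches, batch]. A minimal sketch of that shape flow, assuming square patches and omitting the 2x2 merge reshape the real graph performs; names other than the ggml calls are placeholders:

    #include "ggml.h"
    // inp_raw: [W, H, C, B]; conv with stride == patch_size yields [W/ps, H/ps, hidden, B]
    static ggml_tensor * embed_patches(ggml_context * ctx0, ggml_tensor * patch_w,
                                       ggml_tensor * inp_raw, int ps,
                                       int hidden, int n_patches, int batch) {
        ggml_tensor * inp = ggml_conv_2d(ctx0, patch_w, inp_raw, ps, ps, 0, 0, 1, 1);
        inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 2, 0, 3)); // [w, h, c, b] -> [c, w, h, b]
        return ggml_reshape_3d(ctx0, inp, hidden, n_patches, batch);
    }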
@@ -722,18 +714,6 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
embeddings = ggml_reshape_2d(ctx0, embeddings, hidden_size * 4, patches_w * patches_h * batch_size / 4);
embeddings = ggml_get_rows(ctx0, embeddings, inv_window_idx);
embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size, patches_w * patches_h, batch_size);
// positions = ggml_reshape_2d(ctx0, positions, num_position_ids / 4, 4);
// positions = ggml_cont(ctx0, ggml_permute(ctx0, positions, 1, 0, 2, 3));
// positions = ggml_reshape_2d(ctx0, positions, 16, num_position_ids / 16);
// positions = ggml_get_rows(ctx0, positions, inv_window_idx);
// positions = ggml_reshape_2d(ctx0, positions, 4, num_position_ids / 4);
// positions = ggml_cont(ctx0, ggml_permute(ctx0, positions, 1, 0, 2, 3));
// positions = ggml_reshape_1d(ctx0, positions, num_position_ids);
// ggml_build_forward_expand(gf, embeddings);
// ggml_free(ctx0);
// return gf;
}
for (int il = 0; il < ctx->max_feature_layer; il++) {
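Note: the reshape / get_rows / reshape triple above reorders embeddings so each 2x2-merged patch group lands in window-attention order. The same sequence as a hypothetical standalone helper:

    #include "ggml.h"
    // emb: [hidden, n_patch, batch]; idx: I32 tensor of n_patch * batch / 4 row indices
    static ggml_tensor * reorder_by_window(ggml_context * ctx0, ggml_tensor * emb,
                                           ggml_tensor * idx,
                                           int hidden, int n_patch, int batch) {
        emb = ggml_reshape_2d(ctx0, emb, hidden * 4, n_patch * batch / 4); // one merged group per row
        emb = ggml_get_rows(ctx0, emb, idx);                               // gather rows in window order
        return ggml_reshape_3d(ctx0, emb, hidden, n_patch, batch);
    }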
@@ -757,12 +737,6 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_1_w),
model.layers[il].ln_1_b);
}
// if ( il == 0) {
// // build the graph
// ggml_build_forward_expand(gf, cur);
// ggml_free(ctx0);
// return gf;
// }
// self-attention
{
@@ -805,17 +779,10 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
KQ = ggml_soft_max_ext(ctx0, KQ, nullptr, 1.0f / sqrtf((float)d_head), 0.0f);
} else {
KQ = ggml_soft_max_ext(ctx0, KQ, window_mask, 1.0f, 0.0f);
// KQ = ggml_scale_inplace(ctx0, KQ, 1.0f / sqrt((float)d_head));
// KQ = ggml_add(ctx0, KQ, window_mask);
// KQ = ggml_soft_max_inplace(ctx0, KQ);
}
// if ( il == 0) {
// // build the graph
// ggml_build_forward_expand(gf, KQ);
// ggml_free(ctx0);
// return gf;
// }
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ);
KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_positions, n_head, batch_size);
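Note: the comments deleted here spelled out scale, mask-add, and softmax as three separate ops; ggml_soft_max_ext fuses them (it computes softmax(a * scale + mask)). A sketch of the equivalence, with tensors as placeholders:

    #include "ggml.h"
    #include <math.h>
    static ggml_tensor * masked_attn_softmax(ggml_context * ctx0, ggml_tensor * KQ,
                                             ggml_tensor * mask, float d_head) {
        // unfused form, as in the removed comments:
        //   KQ = ggml_scale_inplace(ctx0, KQ, 1.0f / sqrtf(d_head));
        //   KQ = ggml_add(ctx0, KQ, mask);
        //   KQ = ggml_soft_max_inplace(ctx0, KQ);
        return ggml_soft_max_ext(ctx0, KQ, mask, 1.0f / sqrtf(d_head), 0.0f);
    }

(The kept window-mask call passes scale 1.0f, so the 1/sqrt(d_head) factor is presumably applied elsewhere in the graph.)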
@@ -831,12 +798,6 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
cur = ggml_add(ctx0, cur, embeddings);
embeddings = cur; // embeddings = residual, cur = hidden_states
// if ( il == 0) {
// // build the graph
// ggml_build_forward_expand(gf, cur);
// ggml_free(ctx0);
// return gf;
// }
// layernorm2
if (ctx->use_rms_norm) {
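Note: layernorm2 branches on ctx->use_rms_norm. For reference, the two normalization forms, with eps assumed and weight names following the ln_1 pattern earlier in the diff:

    #include "ggml.h"
    // RMS norm rescales without mean subtraction or bias; classic layernorm does both.
    static ggml_tensor * norm2(ggml_context * ctx0, ggml_tensor * cur,
                               ggml_tensor * w, ggml_tensor * b, bool use_rms, float eps) {
        if (use_rms) {
            return ggml_mul(ctx0, ggml_rms_norm(ctx0, cur, eps), w);
        }
        return ggml_add(ctx0, ggml_mul(ctx0, ggml_norm(ctx0, cur, eps), w), b);
    }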
@@ -888,19 +849,8 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
cur = ggml_add(ctx0, embeddings, cur);
embeddings = cur;
// if ( il == 0) {
// // build the graph
// ggml_build_forward_expand(gf, embeddings);
// ggml_free(ctx0);
// return gf;
// }
}
// ggml_build_forward_expand(gf, embeddings);
// ggml_free(ctx0);
// return gf;
// post-layernorm
if (model.post_ln_w) {
if (ctx->use_rms_norm) {
@@ -2792,9 +2742,6 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
}
struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
// const int pw = image_size_width / patch_size;
// const int ph = image_size_height / patch_size;
const int mpow = (merge_ratio * merge_ratio);
int* positions_data = (int*)malloc(ggml_nbytes(positions));
@@ -2807,6 +2754,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
for (int dx = 0; dx < 2; dx++) {
auto remap = idx[ptr / mpow];
remap = remap * mpow + (ptr % mpow);
// auto remap = ptr;
positions_data[remap] = y + dy;
positions_data[num_patches + remap] = x + dx;
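Note: each position id is written through the window-order remap: idx relocates the 2x2 merged group (ptr / mpow) and ptr % mpow preserves the offset inside the group. The index math in isolation:

    // merge_ratio = 2, so mpow = 4 patches per merged group;
    // idx[g] is the destination group for source group g in window order.
    static inline int remap_pos(const int * idx, int ptr, int mpow) {
        return idx[ptr / mpow] * mpow + (ptr % mpow);
    }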
@@ -2818,7 +2766,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
}
}
if (positions) ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions));
ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions));
free(positions_data);
}
else if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
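Note: ggml_graph_get_tensor returns NULL when the graph defines no tensor named "positions"; the guarded and unguarded tensor_set lines above differ only in whether that case is tolerated. A sketch that states the assumption with an assert instead (hypothetical, not part of the commit):

    struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
    GGML_ASSERT(positions != nullptr); // this code path always builds a "positions" tensor
    ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions));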

View file

@@ -102,7 +102,7 @@ def main(args):
np_dtype = np.float32
ftype = 0
elif args.data_type == 'fp16':
dtype = torch.float32
dtype = torch.float16
np_dtype = np.float16
ftype = 1
else:
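Note: the fp16 branch previously set dtype = torch.float32 while np_dtype was np.float16, so tensors tagged ftype 1 were converted with mismatched torch and numpy dtypes; the one-line fix makes the two agree.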

View file

@@ -771,10 +771,10 @@ enum model_output_type {
};
static void debug_dump_img_embed(struct llava_context * ctx_llava, model_output_type output_type) {
int ih = 140;
int iw = 196;
// int ih = 56;
// int iw = 56;
constexpr int ih = 140;
constexpr int iw = 196;
// constexpr int ih = 56;
// constexpr int iw = 56;
// int n_embd = llama_model_n_embd(llama_get_model(ctx_llava->ctx_llama));
int n_embd = 1280;
int merge = 1;
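Note: promoting ih and iw to constexpr makes the dump dimensions compile-time constants. A sketch of the buffer sizing this implies, assuming the helper stores one n_embd-float embedding per (merged) grid cell; all values taken from the lines above:

    #include <vector>
    constexpr int ih = 140, iw = 196, n_embd = 1280, merge = 1;
    std::vector<float> embed(size_t(ih / merge) * (iw / merge) * n_embd); // hypothetical buffer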
@@ -954,7 +954,7 @@ int main(int argc, char ** argv) {
// debug_test_mrope_2d();
debug_dump_img_embed(ctx_llava, model_output_type::final_layer);
// debug_dump_img_embed(ctx_llava, model_output_type::conv3d);
// debug_dump_img_embed(ctx_llava, model_output_type::last_attn_layer);
// debug_test_get_rows();
// dump_win_attn_mask();
// debug_patch_layout();