diff --git a/otherarch/acestep/dit.h b/otherarch/acestep/dit.h index 7cd4a63b1..acbb4aeca 100644 --- a/otherarch/acestep/dit.h +++ b/otherarch/acestep/dit.h @@ -633,6 +633,10 @@ static struct ggml_tensor * dit_ggml_build_self_attn( ? ggml_flash_attn_ext(ctx, q, k, v, mask, scale, 0.0f, 0.0f) : dit_attn_f32(ctx, q, k, v, mask, scale); + if (m->use_flash_attn) { + ggml_flash_attn_ext_set_prec(attn, GGML_PREC_F32); + } + // Both return [D, Nh, S, N] // Reshape: [D, Nh, S, N] -> [D*Nh, S, N] = [H, S, N] attn = ggml_reshape_3d(ctx, attn, Nh * D, S, N); @@ -742,6 +746,10 @@ static struct ggml_tensor * dit_ggml_build_cross_attn( ? ggml_flash_attn_ext(ctx, q, k, v, NULL, scale, 0.0f, 0.0f) : dit_attn_f32(ctx, q, k, v, NULL, scale); + if (m->use_flash_attn) { + ggml_flash_attn_ext_set_prec(attn, GGML_PREC_F32); + } + // Attention output: [D, Nh, S, N], reshape to [H, S, N] attn = ggml_reshape_3d(ctx, attn, Nh * D, S, N);