remove debug prints

This commit is contained in:
Concedo 2025-06-06 14:08:57 +08:00
parent ca99f79ea9
commit 5f38594dc0
2 changed files with 2 additions and 8 deletions

View file

@ -1246,7 +1246,6 @@ static __global__ void flash_attn_ext_f16(
}
#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
if (ncols1*ncols2 > 32) {
printf("\nCUDA_ARCH:%d, DKQ:%d, DV:%d, ncols1:%d, ncols2:%d, nwarps:%d, ntiles:%d, ne00:%d, ne01:%d\n",__CUDA_ARCH__,DKQ, DV, ncols1, ncols2, nwarps, ntiles,ne00,ne01);
NO_DEVICE_CODE;
return;
}

View file

@ -15,26 +15,21 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_con
if constexpr (ncols2 <= 8) {
if (Q->ne[1] <= 8/ncols2) {
printf("\nCase B: %d %d %d %d %d\n",DKQ,DV,8/ncols2,ncols2,Q->ne[1]);
ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 8/ncols2, ncols2>(ctx, dst);
return;
}
}
if (Q->ne[1] <= 16/ncols2) {
printf("\nCase C: %d %d %d %d %d\n",DKQ,DV,16/ncols2,ncols2,Q->ne[1]);
ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 16/ncols2, ncols2>(ctx, dst);
return;
}
if (ggml_cuda_highest_compiled_arch(cc) <= GGML_CUDA_CC_TURING || Q->ne[1] <= 32/ncols2) {
printf("\nCase D: %d %d %d %d %d\n",DKQ,DV,32/ncols2,ncols2,Q->ne[1]);
ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 32/ncols2, ncols2>(ctx, dst);
return;
}
printf("\nDBG: %d %d %d\n",ggml_cuda_highest_compiled_arch(cc),cc,GGML_CUDA_CC_TURING);
printf("\nCase E: %d %d %d %d %d\n",DKQ,DV,64/ncols2,ncols2,Q->ne[1]);
ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 64/ncols2, ncols2>(ctx, dst);
}
@ -52,7 +47,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_con
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
const int gqa_ratio = Q->ne[2] / K->ne[2];
printf("\ngqa_ratio is %d\n",gqa_ratio);
if (use_gqa_opt && gqa_ratio % 8 == 0) {
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 8>(ctx, dst);
return;
@ -77,7 +72,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
const ggml_tensor * K = dst->src[1];
const ggml_tensor * V = dst->src[2];
const ggml_tensor * mask = dst->src[3];
printf("\nQ->ne[0] is %d\n",Q->ne[0]);
switch (Q->ne[0]) {
case 64:
GGML_ASSERT(V->ne[0] == 64);