From c7e351bf4166dfcc33e5aa73b65002dcf946a06c Mon Sep 17 00:00:00 2001
From: Concedo <39025047+LostRuins@users.noreply.github.com>
Date: Mon, 4 Nov 2024 15:47:13 +0800
Subject: [PATCH] add exception for ibm granite, then keep using f16 kq mul
 for HIPBLAS only for now pending ROCM investigation re
 https://github.com/ggerganov/llama.cpp/pull/10015

---
 src/llama.cpp | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/src/llama.cpp b/src/llama.cpp
index 355770826..b1c1ffdc8 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -9733,16 +9733,30 @@ static struct ggml_tensor * llm_build_kqv(
         cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
                                   hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
 
+#if defined(GGML_USE_HIPBLAS) //workaround for speed regression on rocm
+        if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_GEMMA2 || model.arch == LLM_ARCH_GRANITE || model.arch == LLM_ARCH_GRANITE_MOE) {
+            ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
+        }
+#else
         ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
+#endif
 
         cur = ggml_reshape_2d(ctx, cur, n_embd_head_v*n_head, n_tokens);
     } else {
         struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
         cb(kq, "kq", il);
 
+#if defined(GGML_USE_HIPBLAS) //workaround for speed regression on rocm
+        if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2 || model.arch == LLM_ARCH_NEMOTRON || model.arch == LLM_ARCH_CHATGLM || model.arch == LLM_ARCH_GRANITE || model.arch == LLM_ARCH_GRANITE_MOE) {
+            // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs
+            // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847
+            ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
+        }
+#else
         // note: this op tends to require high floating point range
         //       while for some models F16 is enough, for others it is not, so we default to F32 here
         ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
+#endif
 
         if (model.arch == LLM_ARCH_GROK) {
             // need to do the following: