From 015022bb53387baa8b23817ac03743705c7d472b Mon Sep 17 00:00:00 2001
From: Jeff Bolz
Date: Wed, 16 Apr 2025 13:37:25 -0500
Subject: [PATCH 1/6] vulkan: enable coopmat2 FA gqa and split_k optimizations
 more often (#12931)

The grouped query attention optimization doesn't require a power of two
ratio; the only thing relying on it was the modulo operation written as
a bitwise &.

split_k need not depend on gqa_ratio - enable it any time there's only
one workgroup in the X dimension. The shader gets the split index from
the x coord, and multiple workgroups in the X dimension (pre-split)
indicates a larger FA operation that wouldn't need splitting.
---
 ggml/src/ggml-vulkan/ggml-vulkan.cpp                    | 6 +++---
 ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp | 2 +-
 tests/test-backend-ops.cpp                              | 4 +++-
 3 files changed, 7 insertions(+), 5 deletions(-)

diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 783a0ff86..0e9b2e813 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -5531,7 +5531,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     uint32_t workgroups_y = (uint32_t)neq2;
     uint32_t workgroups_z = (uint32_t)neq3;
 
-    if (N == 1 && qk_ratio > 1 && is_pow2(qk_ratio) && gqa_ratio <= flash_attention_num_small_rows &&
+    if (N == 1 && qk_ratio > 1 && gqa_ratio <= flash_attention_num_small_rows &&
         qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1) {
         // grouped query attention - make the N dimension equal to gqa_ratio, reduce
         // workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1
@@ -5544,8 +5544,8 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     uint32_t split_kv = KV;
     uint32_t split_k = 1;
 
-    if (gqa_ratio > 1 && ctx->device->shader_core_count > 0) {
-        GGML_ASSERT(workgroups_x == 1);
+    // Try to use split_k when KV is large enough to be worth the overhead
+    if (workgroups_x == 1 && ctx->device->shader_core_count > 0 && KV >= 512) {
         // Try to run two workgroups per SM.
         split_k = ctx->device->shader_core_count * 2 / workgroups_y;
         if (split_k > 1) {
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp
index e1baa85f9..b926a578a 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp
@@ -131,7 +131,7 @@ ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in A
 // Load the slope matrix, indexed by Q's dimension 2.
 ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
 {
-    const uint32_t h = iq2 + (r & (p.gqa_ratio - 1));
+    const uint32_t h = iq2 + (r % p.gqa_ratio);
 
     const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
     const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 3a5741c8d..1ee742894 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -4532,7 +4532,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
 
     for (int kv : { 4096, 8192, 16384, }) {
         for (int hs : { 64, 128, }) {
-            test_cases.emplace_back(new test_flash_attn_ext(hs, hs, 8, 4, kv, 1, true, 0, 0, GGML_PREC_F32, GGML_TYPE_F16));
+            for (int nr : { 1, 4, }) {
+                test_cases.emplace_back(new test_flash_attn_ext(hs, hs, 8, nr, kv, 1, true, 0, 0, GGML_PREC_F32, GGML_TYPE_F16));
+            }
         }
     }
 

From 12b17501e6015ffe568ac54fdf08e6580833bf1b Mon Sep 17 00:00:00 2001
From: kimminsu <80271594+kimminsu38oo@users.noreply.github.com>
Date: Thu, 17 Apr 2025 06:25:57 +0900
Subject: [PATCH 2/6] opencl: fix incorrect local_size index in profiling log
 (#12868)

---
 ggml/src/ggml-opencl/ggml-opencl.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp
index fa40abc33..05a2f4e63 100644
--- a/ggml/src/ggml-opencl/ggml-opencl.cpp
+++ b/ggml/src/ggml-opencl/ggml-opencl.cpp
@@ -1521,7 +1521,7 @@ static void ggml_cl2_free(void) {
             info.cmd_complete_duration_ns/1.e6f,
             info.cmd_total_duration_ns/1.e6f,
             info.global_size[0], info.global_size[1], info.global_size[2],
-            info.local_size[0], info.local_size[2], info.local_size[2],
+            info.local_size[0], info.local_size[1], info.local_size[2],
             info.output_size[0], info.output_size[1], info.output_size[2], info.output_size[3]);
     }
     fclose(fperf);

From 971f245b3b5f3f55991bb779cb541b00f82eea1d Mon Sep 17 00:00:00 2001
From: Mikko Juola
Date: Thu, 17 Apr 2025 01:37:05 -0700
Subject: [PATCH 3/6] llama : recognize IBM Granite 3.3 FIM tokens (#12988)

Granite's FIM tokens are very similar to Qwen's; it's just that they use
underscore instead of a dash. So <fim_prefix> for example instead of
<fim-prefix>.

Opening up tokenizer_config.json in ibm-granite/granite-3.3-8b-base shows:

```
    "<fim_prefix>",
    "<fim_middle>",
    "<fim_suffix>",
    "<fim_pad>",
    ...
    "<reponame>",
```
---
 src/llama-vocab.cpp | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp
index 464ff01e0..480605173 100644
--- a/src/llama-vocab.cpp
+++ b/src/llama-vocab.cpp
@@ -1841,6 +1841,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                 if (false
                         || t.first == "<|fim_prefix|>" // Qwen
                         || t.first == "<fim-prefix>"
+                        || t.first == "<fim_prefix>"   // Granite
                         || t.first == "<|fim▁begin|>" // DeepSeek
                         || t.first == "<PRE>"
                         || t.first == "▁<PRE>"          // CodeLlama
@@ -1859,6 +1860,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                 if (false
                         || t.first == "<|fim_suffix|>" // Qwen
                         || t.first == ""
+                        || t.first == ""   // Granite
                         || t.first == "<|fim▁hole|>" // DeepSeek
                         || t.first == ""
                         || t.first == "▁"         // CodeLlama
@@ -1877,6 +1879,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                 if (false
                         || t.first == "<|fim_middle|>" // Qwen
                         || t.first == ""
+                        || t.first == ""   // Granite
                         || t.first == "<|fim▁end|>"  // DeepSeek
                         || t.first == ""
                         || t.first == "▁"         // CodeLlama
@@ -1895,6 +1898,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                 if (false
                         || t.first == "<|fim_pad|>" // Qwen
                         || t.first == ""
+                        || t.first == ""   // Granite
                         || t.first == ""
                         ) {
                     special_fim_pad_id = t.second;
@@ -1913,6 +1917,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                         || t.first == "<|repo_name|>"
                         || t.first == ""
                         || t.first == ""
+                        || t.first == ""    // Granite
                         ) {
                     special_fim_rep_id = t.second;
                     if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {

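For reference, the FIM special tokens recognized above are what infill-style prompting is built from: the prefix and suffix around the cursor are wrapped in these markers and the model generates the missing middle. A minimal sketch of the usual prefix/suffix/middle layout using the Granite spellings follows; whether a given model expects exactly this layout is an assumption to check against its documentation, not something this patch establishes.

```
// Illustration only: assemble a prefix-suffix-middle (PSM) style FIM prompt
// from the Granite token spellings recognized above. The model is expected
// to generate the missing middle after <fim_middle>.
#include <iostream>
#include <string>

int main() {
    const std::string prefix = "def add(a, b):\n    ";    // code before the cursor
    const std::string suffix = "\n\nprint(add(1, 2))\n";  // code after the cursor

    const std::string prompt = "<fim_prefix>" + prefix +
                               "<fim_suffix>" + suffix +
                               "<fim_middle>";

    std::cout << prompt << '\n';
}
```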
From 7a395f67a7a02bb361d944b816d6e933889e28e1 Mon Sep 17 00:00:00 2001
From: hipudding 
Date: Thu, 17 Apr 2025 20:34:16 +0800
Subject: [PATCH 4/6] CANN: Add support for async operator submission (#12864)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Submit operators using asynchronous threads to improve performance.

Use the environment variable GGML_CANN_ASYNC_MODE to control whether
asynchronous submission is enabled. It is disabled by default.

Testing shows a 10%–20% performance improvement in scenarios with
small parameter sizes, especially in quantized models.
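As a rough illustration of the submission model described above (not the
actual ggml-cann code: the class name, queue layout, and the way
GGML_CANN_ASYNC_MODE is read are assumptions), asynchronous submission
amounts to a single worker thread replaying queued operator launches in
order, plus a synchronization point that drains the queue before results
are read back:

```
// Minimal sketch: operators are pushed onto a FIFO queue and launched by a
// dedicated worker thread; synchronize() drains the queue before results
// are read back. Runs inline unless GGML_CANN_ASYNC_MODE is set.
#include <condition_variable>
#include <cstdio>
#include <cstdlib>
#include <functional>
#include <mutex>
#include <queue>
#include <thread>

class async_op_queue {
public:
    async_op_queue() : enabled(std::getenv("GGML_CANN_ASYNC_MODE") != nullptr) {
        if (enabled) {
            worker = std::thread([this] { run(); });
        }
    }

    ~async_op_queue() {
        if (enabled) {
            { std::lock_guard<std::mutex> lk(m); stop = true; }
            cv.notify_one();
            worker.join();
        }
    }

    // Submit one operator launch; runs inline when async mode is off.
    void submit(std::function<void()> op) {
        if (!enabled) { op(); return; }
        { std::lock_guard<std::mutex> lk(m); ops.push(std::move(op)); }
        cv.notify_one();
    }

    // Block until every submitted operator has been launched.
    void synchronize() {
        if (!enabled) return;
        std::unique_lock<std::mutex> lk(m);
        idle.wait(lk, [this] { return ops.empty() && !busy; });
    }

private:
    void run() {
        std::unique_lock<std::mutex> lk(m);
        for (;;) {
            cv.wait(lk, [this] { return stop || !ops.empty(); });
            if (stop && ops.empty()) return;
            auto op = std::move(ops.front());
            ops.pop();
            busy = true;
            lk.unlock();
            op();              // e.g. a GGML_CANN_CALL_ACLNN_OP launch
            lk.lock();
            busy = false;
            idle.notify_all();
        }
    }

    bool enabled = false, stop = false, busy = false;
    std::queue<std::function<void()>> ops;
    std::mutex m;
    std::condition_variable cv, idle;
    std::thread worker;
};

int main() {
    async_op_queue q;
    for (int i = 0; i < 4; ++i) {
        q.submit([i] { std::printf("launch op %d\n", i); });
    }
    q.synchronize();           // results are only read back after this point
}
```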
---
 ggml/src/ggml-cann/aclnn_ops.cpp | 440 ++++++++++++-------------------
 ggml/src/ggml-cann/aclnn_ops.h   | 324 +++++++++++++++++++----
 ggml/src/ggml-cann/common.h      | 136 +++++++++-
 ggml/src/ggml-cann/ggml-cann.cpp |  60 ++---
 4 files changed, 604 insertions(+), 356 deletions(-)

diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp
index 2c6737ea8..67c0223c0 100644
--- a/ggml/src/ggml-cann/aclnn_ops.cpp
+++ b/ggml/src/ggml-cann/aclnn_ops.cpp
@@ -103,9 +103,7 @@ void ggml_cann_unary_op(
     aclTensor* acl_dst = ggml_cann_create_tensor(dst);
 
     unary_op(ctx, acl_src, acl_dst);
-
-    ACL_CHECK(aclDestroyTensor(acl_src));
-    ACL_CHECK(aclDestroyTensor(acl_dst));
+    ggml_cann_release_resources(ctx, acl_src, acl_dst);
 }
 
 /**
@@ -123,8 +121,8 @@ static void aclnn_repeat(ggml_backend_cann_context& ctx, aclTensor* acl_src,
     // repeat tensor along each dim with repeat_array
     aclIntArray* repeats = aclCreateIntArray(repeat_array, GGML_MAX_DIMS);
 
-    GGML_CANN_CALL_ACLNN_OP(Repeat, acl_src, repeats, acl_dst);
-    ACL_CHECK(aclDestroyIntArray(repeats));
+    GGML_CANN_CALL_ACLNN_OP(ctx, Repeat, acl_src, repeats, acl_dst);
+    ggml_cann_release_resources(ctx, repeats);
 }
 
 /**
@@ -142,7 +140,7 @@ static void aclnn_repeat(ggml_backend_cann_context& ctx, aclTensor* acl_src,
  */
 static void aclnn_cast(ggml_backend_cann_context& ctx, aclTensor* acl_src,
     aclTensor* acl_dst, aclDataType cast_data_type) {
-    GGML_CANN_CALL_ACLNN_OP(Cast, acl_src, cast_data_type, acl_dst);
+    GGML_CANN_CALL_ACLNN_OP(ctx, Cast, acl_src, cast_data_type, acl_dst);
 }
 
 void ggml_cann_repeat(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
@@ -156,8 +154,7 @@ void ggml_cann_repeat(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
                               dst->ne[1] / src->ne[1], dst->ne[0] / src->ne[0]};
 
     aclnn_repeat(ctx, acl_src, acl_dst, repeatsArray);
-    ACL_CHECK(aclDestroyTensor(acl_src));
-    ACL_CHECK(aclDestroyTensor(acl_dst));
+    ggml_cann_release_resources(ctx, acl_src, acl_dst);
 }
 
 void aclnn_add(ggml_backend_cann_context& ctx, aclTensor* acl_src0,
@@ -165,10 +162,10 @@ void aclnn_add(ggml_backend_cann_context& ctx, aclTensor* acl_src0,
     float alphaValue = 1.0f;
     aclScalar* alpha = aclCreateScalar(&alphaValue, aclDataType::ACL_FLOAT);
     if (acl_dst != nullptr)
-        GGML_CANN_CALL_ACLNN_OP(Add, acl_src0, acl_src1, alpha, acl_dst);
+        GGML_CANN_CALL_ACLNN_OP(ctx, Add, acl_src0, acl_src1, alpha, acl_dst);
     else
-        GGML_CANN_CALL_ACLNN_OP(InplaceAdd, acl_src0, acl_src1, alpha);
-    ACL_CHECK(aclDestroyScalar(alpha));
+        GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdd, acl_src0, acl_src1, alpha);
+    ggml_cann_release_resources(ctx, alpha);
 }
 
 void aclnn_sub(ggml_backend_cann_context& ctx, aclTensor* acl_src0,
@@ -176,26 +173,26 @@ void aclnn_sub(ggml_backend_cann_context& ctx, aclTensor* acl_src0,
     float alphaValue = 1.0f;
     aclScalar* alpha = aclCreateScalar(&alphaValue, aclDataType::ACL_FLOAT);
     if (acl_dst != nullptr)
-        GGML_CANN_CALL_ACLNN_OP(Sub, acl_src0, acl_src1, alpha, acl_dst);
+        GGML_CANN_CALL_ACLNN_OP(ctx, Sub, acl_src0, acl_src1, alpha, acl_dst);
     else
-        GGML_CANN_CALL_ACLNN_OP(InplaceSub, acl_src0, acl_src1, alpha);
-    ACL_CHECK(aclDestroyScalar(alpha));
+        GGML_CANN_CALL_ACLNN_OP(ctx, InplaceSub, acl_src0, acl_src1, alpha);
+    ggml_cann_release_resources(ctx, alpha);
 }
 
 void aclnn_mul(ggml_backend_cann_context& ctx, aclTensor* acl_src,
     aclTensor* acl_other, aclTensor* acl_dst) {
     if (acl_dst != nullptr)
-        GGML_CANN_CALL_ACLNN_OP(Mul, acl_src, acl_other, acl_dst);
+        GGML_CANN_CALL_ACLNN_OP(ctx, Mul, acl_src, acl_other, acl_dst);
     else
-        GGML_CANN_CALL_ACLNN_OP(InplaceMul, acl_src, acl_other);
+        GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMul, acl_src, acl_other);
 }
 
 void aclnn_div(ggml_backend_cann_context& ctx, aclTensor* acl_src,
     aclTensor* acl_other, aclTensor* acl_dst) {
     if (acl_dst != nullptr)
-        GGML_CANN_CALL_ACLNN_OP(Div, acl_src, acl_other, acl_dst);
+        GGML_CANN_CALL_ACLNN_OP(ctx, Div, acl_src, acl_other, acl_dst);
     else
-        GGML_CANN_CALL_ACLNN_OP(InplaceDiv, acl_src, acl_other);
+        GGML_CANN_CALL_ACLNN_OP(ctx, InplaceDiv, acl_src, acl_other);
 }
 
 /**
@@ -224,11 +221,11 @@ static void aclnn_muls(ggml_backend_cann_context& ctx, aclTensor* acl_src,
     float scale, aclTensor* acl_dst, bool inplace) {
     aclScalar* acl_scale = aclCreateScalar(&scale, aclDataType::ACL_FLOAT);
     if (inplace) {
-        GGML_CANN_CALL_ACLNN_OP(InplaceMuls, acl_src, acl_scale);
+        GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_src, acl_scale);
     } else {
-        GGML_CANN_CALL_ACLNN_OP(Muls, acl_src, acl_scale, acl_dst);
+        GGML_CANN_CALL_ACLNN_OP(ctx, Muls, acl_src, acl_scale, acl_dst);
     }
-    ACL_CHECK(aclDestroyScalar(acl_scale));
+    ggml_cann_release_resources(ctx, acl_scale);
 }
 
 void ggml_cann_leaky_relu(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
@@ -245,11 +242,8 @@ void ggml_cann_leaky_relu(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     aclScalar* acl_negative_slope =
         aclCreateScalar(&negative_slope, aclDataType::ACL_FLOAT);
 
-    GGML_CANN_CALL_ACLNN_OP(LeakyRelu, acl_src, acl_negative_slope, acl_dst);
-
-    ACL_CHECK(aclDestroyScalar(acl_negative_slope));
-    ACL_CHECK(aclDestroyTensor(acl_src));
-    ACL_CHECK(aclDestroyTensor(acl_dst));
+    GGML_CANN_CALL_ACLNN_OP(ctx, LeakyRelu, acl_src, acl_negative_slope, acl_dst);
+    ggml_cann_release_resources(ctx, acl_negative_slope, acl_src, acl_dst);
 }
 
 /**
@@ -265,7 +259,7 @@ void ggml_cann_leaky_relu(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
 static void aclnn_concat(ggml_backend_cann_context& ctx,
                          aclTensorList* tensorList, aclTensor* acl_dst,
                          int64_t concat_dim) {
-    GGML_CANN_CALL_ACLNN_OP(Cat, tensorList, concat_dim, acl_dst);
+    GGML_CANN_CALL_ACLNN_OP(ctx, Cat, tensorList, concat_dim, acl_dst);
 }
 
 void ggml_cann_concat(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
@@ -281,11 +275,10 @@ void ggml_cann_concat(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     int32_t acl_dim = 3 - dim;
 
     aclTensor* tensors[] = {acl_src0, acl_src1};
-    aclTensorList* tensorList = aclCreateTensorList(tensors, 2);
-    aclnn_concat(ctx, tensorList, acl_dst, acl_dim);
+    aclTensorList* tensor_list = aclCreateTensorList(tensors, 2);
+    aclnn_concat(ctx, tensor_list, acl_dst, acl_dim);
 
-    ACL_CHECK(aclDestroyTensorList(tensorList));
-    ACL_CHECK(aclDestroyTensor(acl_dst));
+    ggml_cann_release_resources(ctx, tensor_list, acl_dst);
 }
 
 /**
@@ -315,10 +308,8 @@ static void aclnn_arange(ggml_backend_cann_context& ctx, aclTensor* acl_dst,
     aclScalar* acl_end = aclCreateScalar(&stop, aclDataType::ACL_FLOAT);
     aclScalar* acl_step = aclCreateScalar(&step, aclDataType::ACL_FLOAT);
 
-    GGML_CANN_CALL_ACLNN_OP(Arange, acl_start, acl_end, acl_step, acl_dst);
-    ACL_CHECK(aclDestroyScalar(acl_start));
-    ACL_CHECK(aclDestroyScalar(acl_end));
-    ACL_CHECK(aclDestroyScalar(acl_step));
+    GGML_CANN_CALL_ACLNN_OP(ctx, Arange, acl_start, acl_end, acl_step, acl_dst);
+    ggml_cann_release_resources(ctx, acl_start, acl_end, acl_step);
 }
 
 void ggml_cann_arange(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
@@ -335,7 +326,7 @@ void ggml_cann_arange(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     memcpy(&step, (float*)dst->op_params + 2, sizeof(float));
 
     aclnn_arange(ctx, acl_dst, start, stop, step, n_elements);
-    ACL_CHECK(aclDestroyTensor(acl_dst));
+    ggml_cann_release_resources(ctx, acl_dst);
 }
 
 void ggml_cann_clamp(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
@@ -352,11 +343,8 @@ void ggml_cann_clamp(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     aclScalar* acl_min = aclCreateScalar(&min, aclDataType::ACL_FLOAT);
     aclScalar* acl_max = aclCreateScalar(&max, aclDataType::ACL_FLOAT);
 
-    GGML_CANN_CALL_ACLNN_OP(Clamp, acl_src, acl_min, acl_max, acl_dst);
-    ACL_CHECK(aclDestroyScalar(acl_min));
-    ACL_CHECK(aclDestroyScalar(acl_max));
-    ACL_CHECK(aclDestroyTensor(acl_src));
-    ACL_CHECK(aclDestroyTensor(acl_dst));
+    GGML_CANN_CALL_ACLNN_OP(ctx, Clamp, acl_src, acl_min, acl_max, acl_dst);
+    ggml_cann_release_resources(ctx, acl_min, acl_max, acl_src, acl_dst);
 }
 
 void ggml_cann_scale(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
@@ -370,10 +358,8 @@ void ggml_cann_scale(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     aclTensor* acl_src = ggml_cann_create_tensor(src);
     aclTensor* acl_dst = ggml_cann_create_tensor(dst);
 
-    GGML_CANN_CALL_ACLNN_OP(Muls, acl_src, scale, acl_dst);
-    ACL_CHECK(aclDestroyScalar(scale));
-    ACL_CHECK(aclDestroyTensor(acl_src));
-    ACL_CHECK(aclDestroyTensor(acl_dst));
+    GGML_CANN_CALL_ACLNN_OP(ctx, Muls, acl_src, scale, acl_dst);
+    ggml_cann_release_resources(ctx, scale, acl_src, acl_dst);
 }
 
 void ggml_cann_argsort(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
@@ -388,12 +374,10 @@ void ggml_cann_argsort(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     aclTensor* tmp_tensor =
         ggml_cann_create_tensor(buffer, ACL_INT64, ggml_type_size(dst->type),
                                 dst->ne, dst->nb, GGML_MAX_DIMS);
-    GGML_CANN_CALL_ACLNN_OP(Argsort, acl_src, -1, (order == GGML_SORT_ORDER_DESC ? true : false),
+    GGML_CANN_CALL_ACLNN_OP(ctx, Argsort, acl_src, -1, (order == GGML_SORT_ORDER_DESC ? true : false),
                       tmp_tensor);
-    GGML_CANN_CALL_ACLNN_OP(Cast, tmp_tensor, ggml_cann_type_mapping(dst->type), acl_dst);
-    ACL_CHECK(aclDestroyTensor(acl_src));
-    ACL_CHECK(aclDestroyTensor(tmp_tensor));
-    ACL_CHECK(aclDestroyTensor(acl_dst));
+    GGML_CANN_CALL_ACLNN_OP(ctx, Cast, tmp_tensor, ggml_cann_type_mapping(dst->type), acl_dst);
+    ggml_cann_release_resources(ctx, acl_src, tmp_tensor, acl_dst);
 }
 
 void ggml_cann_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
@@ -407,11 +391,9 @@ void ggml_cann_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
 
     std::vector<int64_t> normData = {dst->ne[0]};
     aclIntArray* norm = aclCreateIntArray(normData.data(), normData.size());
-    GGML_CANN_CALL_ACLNN_OP(LayerNorm, acl_src, norm, nullptr, nullptr,
+    GGML_CANN_CALL_ACLNN_OP(ctx, LayerNorm, acl_src, norm, nullptr, nullptr,
                     eps, acl_dst, nullptr, nullptr);
-    ACL_CHECK(aclDestroyIntArray(norm));
-    ACL_CHECK(aclDestroyTensor(acl_src));
-    ACL_CHECK(aclDestroyTensor(acl_dst));
+    ggml_cann_release_resources(ctx, norm, acl_src, acl_dst);
 }
 
 void ggml_cann_group_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
@@ -441,12 +423,9 @@ void ggml_cann_group_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     aclTensor* acl_rstd_out = ggml_cann_create_tensor(
         (char*)buffer + n_bytes, ACL_FLOAT, type_size, ne, nb, ACL_FORMAT_ND);
 
-    GGML_CANN_CALL_ACLNN_OP(GroupNorm, acl_src, nullptr, nullptr, N, C, HxW, n_groups, eps,
+    GGML_CANN_CALL_ACLNN_OP(ctx, GroupNorm, acl_src, nullptr, nullptr, N, C, HxW, n_groups, eps,
         acl_dst, acl_mean_out, acl_rstd_out);
-    ACL_CHECK(aclDestroyTensor(acl_src));
-    ACL_CHECK(aclDestroyTensor(acl_dst));
-    ACL_CHECK(aclDestroyTensor(acl_mean_out));
-    ACL_CHECK(aclDestroyTensor(acl_rstd_out));
+    ggml_cann_release_resources(ctx, acl_src, acl_dst, acl_mean_out, acl_rstd_out);
 }
 
 void ggml_cann_acc(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
@@ -471,19 +450,17 @@ void ggml_cann_acc(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
 
     if (!inplace) {
         size_t cpy_size = ggml_nbytes(dst);
-        ACL_CHECK(aclrtMemcpyAsync(dst->data, cpy_size, src0->data, cpy_size,
-                                   ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream()));
+        ggml_cann_async_memcpy(ctx, dst->data, src0->data, cpy_size,
+            ACL_MEMCPY_DEVICE_TO_DEVICE);
         aclTensor* acl_src0 = ggml_cann_create_tensor(
             src0, src1->ne, src0->nb, GGML_MAX_DIMS, ACL_FORMAT_ND, offset);
 
-        GGML_CANN_CALL_ACLNN_OP(Add, acl_src0, acl_src1, alpha, acl_dst);
-        ACL_CHECK(aclDestroyTensor(acl_src0));
+        GGML_CANN_CALL_ACLNN_OP(ctx, Add, acl_src0, acl_src1, alpha, acl_dst);
+        ggml_cann_release_resources(ctx, acl_src0);
     } else {
-        GGML_CANN_CALL_ACLNN_OP(InplaceAdd, acl_dst, acl_src1, alpha);
+        GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdd, acl_dst, acl_src1, alpha);
     }
-
-    ACL_CHECK(aclDestroyTensor(acl_src1));
-    ACL_CHECK(aclDestroyTensor(acl_dst));
+    ggml_cann_release_resources(ctx, acl_src1, acl_dst);
 }
 
 /**
@@ -496,7 +473,6 @@ void ggml_cann_acc(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
  * @param dim An array of dimension indices.
  * @param dim_size The number of dimensions.
  */
-
 static void aclnn_reduce_sum(ggml_backend_cann_context& ctx, ggml_tensor* dst,
                              int64_t* dim, size_t dim_size) {
     GGML_ASSERT(dst->ne[0] == 1);
@@ -505,11 +481,9 @@ static void aclnn_reduce_sum(ggml_backend_cann_context& ctx, ggml_tensor* dst,
     aclTensor* acl_dst = ggml_cann_create_tensor(dst);
     aclIntArray* reduce_dims = aclCreateIntArray(dim, dim_size);
 
-    GGML_CANN_CALL_ACLNN_OP(ReduceSum, acl_src, reduce_dims, true,
+    GGML_CANN_CALL_ACLNN_OP(ctx, ReduceSum, acl_src, reduce_dims, true,
                       ggml_cann_type_mapping(dst->type), acl_dst);
-    ACL_CHECK(aclDestroyTensor(acl_src));
-    ACL_CHECK(aclDestroyTensor(acl_dst));
-    ACL_CHECK(aclDestroyIntArray(reduce_dims));
+    ggml_cann_release_resources(ctx, acl_src, acl_dst, reduce_dims);
 }
 
 void ggml_cann_sum_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
@@ -533,10 +507,8 @@ void ggml_cann_upsample_nearest2d(ggml_backend_cann_context& ctx,
     std::vector<int64_t> output_size{dst->ne[1], dst->ne[0]};
     auto output_size_array = aclCreateIntArray(output_size.data(), 2);
 
-    GGML_CANN_CALL_ACLNN_OP(UpsampleNearest2d, acl_src, output_size_array, acl_dst);
-    ACL_CHECK(aclDestroyIntArray(output_size_array));
-    ACL_CHECK(aclDestroyTensor(acl_src));
-    ACL_CHECK(aclDestroyTensor(acl_dst));
+    GGML_CANN_CALL_ACLNN_OP(ctx, UpsampleNearest2d, acl_src, output_size_array, acl_dst);
+    ggml_cann_release_resources(ctx, acl_src, acl_dst, output_size_array);
 }
 
 /**
@@ -559,9 +531,8 @@ static void aclnn_pad(ggml_backend_cann_context& ctx, aclTensor* acl_src,
     aclIntArray* acl_pad = aclCreateIntArray(paddings, GGML_MAX_DIMS * 2);
     aclScalar* acl_value = aclCreateScalar(&value, aclDataType::ACL_FLOAT);
 
-    GGML_CANN_CALL_ACLNN_OP(ConstantPadNd, acl_src, acl_pad, acl_value, acl_dst);
-    ACL_CHECK(aclDestroyIntArray(acl_pad));
-    ACL_CHECK(aclDestroyScalar(acl_value));
+    GGML_CANN_CALL_ACLNN_OP(ctx, ConstantPadNd, acl_src, acl_pad, acl_value, acl_dst);
+    ggml_cann_release_resources(ctx, acl_pad, acl_value);
 }
 
 void ggml_cann_pad(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
@@ -577,9 +548,7 @@ void ggml_cann_pad(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
         0, dst->ne[0] - src->ne[0], 0, dst->ne[1] - src->ne[1],
         0, dst->ne[2] - src->ne[2], 0, dst->ne[3] - src->ne[3]};
     aclnn_pad(ctx, acl_src, acl_dst, paddings);
-
-    ACL_CHECK(aclDestroyTensor(acl_dst));
-    ACL_CHECK(aclDestroyTensor(acl_src));
+    ggml_cann_release_resources(ctx, acl_src, acl_dst);
 }
 
 /**
@@ -629,14 +598,11 @@ static void ggml_cann_avg_pool2d(ggml_backend_cann_context& ctx,
     cube_math_type = 1;
 #endif
 
-    GGML_CANN_CALL_ACLNN_OP(AvgPool2d, acl_src, kernel_size, strides, paddings_avg,
+    GGML_CANN_CALL_ACLNN_OP(ctx, AvgPool2d, acl_src, kernel_size, strides, paddings_avg,
                     ceil_mode, count_include_pad, divisor_override,
                     cube_math_type, acl_dst);
-    ACL_CHECK(aclDestroyTensor(acl_src));
-    ACL_CHECK(aclDestroyTensor(acl_dst));
-    ACL_CHECK(aclDestroyIntArray(kernel_size));
-    ACL_CHECK(aclDestroyIntArray(strides));
-    ACL_CHECK(aclDestroyIntArray(paddings_avg));
+    ggml_cann_release_resources(ctx, acl_src, acl_dst, kernel_size, strides,
+                                paddings_avg);
 }
 
 /**
@@ -704,15 +670,10 @@ static void ggml_cann_max_pool2d(ggml_backend_cann_context& ctx,
 
     bool ceil_mode = false;
     int64_t auto_pads = 0;
-    GGML_CANN_CALL_ACLNN_OP(MaxPool, tmp_tensor, kernel_size, strides, auto_pads,
+    GGML_CANN_CALL_ACLNN_OP(ctx, MaxPool, tmp_tensor, kernel_size, strides, auto_pads,
                     paddings_max, dilations, ceil_mode, acl_dst);
-    ACL_CHECK(aclDestroyTensor(acl_src));
-    ACL_CHECK(aclDestroyTensor(acl_dst));
-    ACL_CHECK(aclDestroyTensor(tmp_tensor));
-    ACL_CHECK(aclDestroyIntArray(kernel_size));
-    ACL_CHECK(aclDestroyIntArray(strides));
-    ACL_CHECK(aclDestroyIntArray(paddings_max));
-    ACL_CHECK(aclDestroyIntArray(dilations));
+    ggml_cann_release_resources(ctx, acl_src, acl_dst, tmp_tensor, kernel_size,
+                                strides, paddings_max, dilations);
 }
 
 void ggml_cann_pool2d(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
@@ -743,7 +704,7 @@ void ggml_cann_pool2d(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
  */
 static void cann_copy(ggml_backend_cann_context& ctx, aclTensor* acl_src,
                       aclTensor* acl_dst) {
-    GGML_CANN_CALL_ACLNN_OP(InplaceCopy, acl_dst, acl_src);
+    GGML_CANN_CALL_ACLNN_OP(ctx, InplaceCopy, acl_dst, acl_src);
 }
 
 void ggml_cann_dup(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
@@ -761,9 +722,8 @@ void ggml_cann_dup(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
         if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst)) {
             if (dst->type == src0->type) {
                 size_t cpy_size = ggml_nbytes(dst);
-                ACL_CHECK(aclrtMemcpyAsync(
-                    dst->data, cpy_size, src0->data, cpy_size,
-                    ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream()));
+                ggml_cann_async_memcpy(ctx, dst->data, src0->data, cpy_size,
+                    ACL_MEMCPY_DEVICE_TO_DEVICE);
                 return;
             } else {
                 ggml_cann_pool_alloc src_buffer_allocator(
@@ -782,10 +742,9 @@ void ggml_cann_dup(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
 
                 aclnn_cast(ctx, acl_src, src_trans_tensor, ggml_cann_type_mapping(dst->type));
                 size_t cpy_size = ggml_nbytes(dst);
-                ACL_CHECK(aclrtMemcpyAsync(
-                    dst->data, cpy_size, src_trans_buffer, cpy_size,
-                    ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream()));
-                ACL_CHECK(aclDestroyTensor(src_trans_tensor));
+                ggml_cann_async_memcpy(ctx, dst->data, src_trans_buffer, cpy_size,
+                    ACL_MEMCPY_DEVICE_TO_DEVICE);
+                ggml_cann_release_resources(ctx, src_trans_tensor);
                 return;
             }
         } else if (ggml_is_contiguous(dst)) {
@@ -805,18 +764,15 @@ void ggml_cann_dup(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
             aclnn_cast(ctx, acl_src, src_trans_tensor, ggml_cann_type_mapping(dst->type));
 
             size_t cpy_size = ggml_nbytes(dst);
-            ACL_CHECK(aclrtMemcpyAsync(dst->data, cpy_size, src_trans_buffer,
-                                       cpy_size, ACL_MEMCPY_DEVICE_TO_DEVICE,
-                                       ctx.stream()));
-            ACL_CHECK(aclDestroyTensor(src_trans_tensor));
+            ggml_cann_async_memcpy(ctx, dst->data, src_trans_buffer, cpy_size,
+                ACL_MEMCPY_DEVICE_TO_DEVICE);
+            ggml_cann_release_resources(ctx, src_trans_tensor);
             return;
         } else {
             GGML_ABORT("Unsupport dst is not tontiguous.");
         }
     }
-
-    ACL_CHECK(aclDestroyTensor(acl_src));
-    ACL_CHECK(aclDestroyTensor(acl_dst));
+    ggml_cann_release_resources(ctx, acl_src, acl_dst);
 }
 
 /**
@@ -844,7 +800,7 @@ static aclTensor* aclnn_zero(ggml_backend_cann_context& ctx, void* buffer,
         nb[i] = nb[i - 1] * ne[i - 1];
     }
 
-    ACL_CHECK(aclrtMemsetAsync(buffer, n_bytes, 0, n_bytes, ctx.stream()));
+    ggml_cann_async_memset(ctx, buffer, n_bytes, 0);
     aclTensor* zero =
         ggml_cann_create_tensor(buffer, type, type_size, ne, nb, dims);
     return zero;
@@ -877,7 +833,7 @@ static aclTensor* aclnn_values(ggml_backend_cann_context& ctx, void* buffer,
     float alpha_host = 1.0f;
     aclScalar* alpha = aclCreateScalar(&alpha_host, aclDataType::ACL_FLOAT);
     aclScalar* other = aclCreateScalar(&value, aclDataType::ACL_FLOAT);
-    GGML_CANN_CALL_ACLNN_OP(InplaceAdds, acl_tensor, other, alpha);
+    GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdds, acl_tensor, other, alpha);
     return acl_tensor;
 }
 
@@ -903,11 +859,8 @@ void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
         aclnn_zero(ctx, zero_tensor_allocator.get(), zero_tensor_n_bytes,
                    src->ne, GGML_MAX_DIMS, ggml_cann_type_mapping(src->type),
                    ggml_element_size(src));
-    GGML_CANN_CALL_ACLNN_OP(RmsNorm, acl_src, acl_gamma, eps, acl_dst, acl_rstd);
-    ACL_CHECK(aclDestroyTensor(acl_src));
-    ACL_CHECK(aclDestroyTensor(acl_dst));
-    ACL_CHECK(aclDestroyTensor(acl_gamma));
-    ACL_CHECK(aclDestroyTensor(acl_rstd));
+    GGML_CANN_CALL_ACLNN_OP(ctx, RmsNorm, acl_src, acl_gamma, eps, acl_dst, acl_rstd);
+    ggml_cann_release_resources(ctx, acl_src, acl_dst, acl_gamma, acl_rstd);
 }
 
 // TODO: performace is low.
@@ -933,13 +886,10 @@ void ggml_cann_diag_mask(ggml_backend_cann_context& ctx, ggml_tensor* dst,
     float alphaValue = 1.0f;
     alpha = aclCreateScalar(&alphaValue, aclDataType::ACL_FLOAT);
 
-    GGML_CANN_CALL_ACLNN_OP(InplaceTriu, mask_tensor, n_past + 1);
-    GGML_CANN_CALL_ACLNN_OP(Tril, acl_src, n_past + 1, acl_dst);
-    GGML_CANN_CALL_ACLNN_OP(InplaceAdd, acl_dst, mask_tensor, alpha);
-    ACL_CHECK(aclDestroyScalar(alpha));
-    ACL_CHECK(aclDestroyTensor(mask_tensor));
-    ACL_CHECK(aclDestroyTensor(acl_src));
-    ACL_CHECK(aclDestroyTensor(acl_dst));
+    GGML_CANN_CALL_ACLNN_OP(ctx, InplaceTriu, mask_tensor, n_past + 1);
+    GGML_CANN_CALL_ACLNN_OP(ctx, Tril, acl_src, n_past + 1, acl_dst);
+    GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdd, acl_dst, mask_tensor, alpha);
+    ggml_cann_release_resources(ctx, alpha, acl_src, acl_dst, mask_tensor);
 }
 
 /**
@@ -960,7 +910,8 @@ void ggml_cann_diag_mask(ggml_backend_cann_context& ctx, ggml_tensor* dst,
 static void aclnn_permute(ggml_backend_cann_context& ctx, aclTensor* acl_src,
                           aclTensor* acl_dst, int64_t* new_dim, uint64_t dims) {
     aclIntArray* acl_dims = aclCreateIntArray(new_dim, dims);
-    GGML_CANN_CALL_ACLNN_OP(Permute, acl_src, acl_dims, acl_dst);
+    GGML_CANN_CALL_ACLNN_OP(ctx, Permute, acl_src, acl_dims, acl_dst);
+    ggml_cann_release_resources(ctx, acl_dims);
 }
 
 static void ggml_cann_im2col_2d_post_process(ggml_backend_cann_context& ctx,
@@ -981,8 +932,7 @@ static void ggml_cann_im2col_2d_post_process(ggml_backend_cann_context& ctx,
         aclnn_permute(ctx, tmp_im2col_tensor, acl_dst, permute_dim, 3);
     }
 
-    // release
-    ACL_CHECK(aclDestroyTensor(acl_dst));
+    ggml_cann_release_resources(ctx, acl_dst);
 }
 
 static void ggml_cann_im2col_1d_post_process(
@@ -1004,7 +954,6 @@ static void ggml_cann_im2col_1d_post_process(
 
     // Permute: [N, IC * KH * KW, OW * OH] ->
     // [N, OW * OH * n_bytes_factor, IC * KH * KW]
-    aclTensor* tmp_permute_tensor = nullptr;
     ggml_cann_pool_alloc tmp_permute_allocator(ctx.pool());
     tmp_permute_allocator.alloc(ggml_nbytes(dst) * n_bytes_factor);
     void* tmp_permute_buffer = tmp_permute_allocator.get();
@@ -1016,7 +965,7 @@ static void ggml_cann_im2col_1d_post_process(
         tmp_permute_nb[i] = tmp_permute_nb[i - 1] * tmp_permute_ne[i - 1];
     }
 
-    tmp_permute_tensor = ggml_cann_create_tensor(
+    aclTensor* tmp_permute_tensor = ggml_cann_create_tensor(
         tmp_permute_buffer, ggml_cann_type_mapping(dst->type),
         ggml_type_size(dst->type), tmp_permute_ne, tmp_permute_nb,
         GGML_MAX_DIMS - 1, ACL_FORMAT_ND);
@@ -1046,9 +995,8 @@ static void ggml_cann_im2col_1d_post_process(
                              c * KH * KW * n_step_w * ggml_type_size(dst->type);
 
             for (int i = 0; i < n_step_w; i++) {
-                ACL_CHECK(aclrtMemcpyAsync(
-                    cur_dst_buffer, size_cpy, cur_permute_buffer, size_cpy,
-                    ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream()));
+                ggml_cann_async_memcpy(ctx, cur_dst_buffer, cur_permute_buffer, size_cpy,
+                    ACL_MEMCPY_DEVICE_TO_DEVICE);
                 cur_dst_buffer =
                     (char*)cur_dst_buffer + KH * KW * ggml_type_size(dst->type);
                 cur_permute_buffer = (char*)cur_permute_buffer +
@@ -1058,13 +1006,11 @@ static void ggml_cann_im2col_1d_post_process(
     } else {
         offset = KH * KW * n_step_w *
                  ggml_type_size(dst->type);  // equal to ggml_nbytes(dst)
-        ACL_CHECK(aclrtMemcpyAsync(dst->data, offset,
-                                   (char*)tmp_permute_buffer + offset, offset,
-                                   ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream()));
+        ggml_cann_async_memcpy(ctx, dst->data, (char*)tmp_permute_buffer + offset, offset,
+            ACL_MEMCPY_DEVICE_TO_DEVICE);
     }
 
-    // release
-    ACL_CHECK(aclDestroyTensor(tmp_permute_tensor));
+    ggml_cann_release_resources(ctx, tmp_permute_tensor);
 }
 
 void ggml_cann_im2col(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
@@ -1126,7 +1072,7 @@ void ggml_cann_im2col(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     auto* dilations = aclCreateIntArray(dilation_size.data(), 2);
     auto* paddings = aclCreateIntArray(padding_dims.data(), 2);
     auto* strides = aclCreateIntArray(stride_dims.data(), 2);
-    GGML_CANN_CALL_ACLNN_OP(Im2col, acl_src1, kernel_size, dilations,
+    GGML_CANN_CALL_ACLNN_OP(ctx, Im2col, acl_src1, kernel_size, dilations,
                     paddings, strides, tmp_im2col_tensor);
 
     // Cast if dst is f16.
@@ -1160,14 +1106,8 @@ void ggml_cann_im2col(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
                                          tmp_im2col_tensor, im2col_op_params);
     }
 
-    // release
-    ACL_CHECK(aclDestroyTensor(acl_src1));
-    ACL_CHECK(aclDestroyTensor(tmp_im2col_tensor));
-    ACL_CHECK(aclDestroyTensor(tmp_cast_tensor));
-    ACL_CHECK(aclDestroyIntArray(kernel_size));
-    ACL_CHECK(aclDestroyIntArray(dilations));
-    ACL_CHECK(aclDestroyIntArray(paddings));
-    ACL_CHECK(aclDestroyIntArray(strides));
+    ggml_cann_release_resources(ctx, acl_src1, tmp_im2col_tensor, tmp_cast_tensor,
+        kernel_size, dilations, paddings, strides);
 }
 
 /**
@@ -1184,17 +1124,17 @@ void ggml_cann_im2col(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
  * @param acl_src The tensor on which the exponential function will be applied.
  */
 static void aclnn_exp(ggml_backend_cann_context& ctx, aclTensor* acl_src) {
-    GGML_CANN_CALL_ACLNN_OP(InplaceExp, acl_src);
+    GGML_CANN_CALL_ACLNN_OP(ctx, InplaceExp, acl_src);
 }
 
 void aclnn_cos(ggml_backend_cann_context& ctx, aclTensor* acl_src,
                       aclTensor* acl_dst) {
-    GGML_CANN_CALL_ACLNN_OP(Cos, acl_src, acl_dst);
+    GGML_CANN_CALL_ACLNN_OP(ctx, Cos, acl_src, acl_dst);
 }
 
 void aclnn_sin(ggml_backend_cann_context& ctx, aclTensor* acl_src,
                       aclTensor* acl_dst) {
-    GGML_CANN_CALL_ACLNN_OP(Sin, acl_src, acl_dst);
+    GGML_CANN_CALL_ACLNN_OP(ctx, Sin, acl_src, acl_dst);
 }
 
 void ggml_cann_timestep_embedding(ggml_backend_cann_context& ctx,
@@ -1243,13 +1183,13 @@ void ggml_cann_timestep_embedding(ggml_backend_cann_context& ctx,
 
     ggml_cann_pool_alloc permute_allocator(ctx.pool(), ggml_nbytes(src));
     void* tmp_permute_buffer = permute_allocator.get();
-    aclTensor* tmp_permute_tenosr = ggml_cann_create_tensor(
+    aclTensor* tmp_permute_tensor = ggml_cann_create_tensor(
         tmp_permute_buffer, ggml_cann_type_mapping(src->type),
         ggml_type_size(src->type), tmp_permute_ne, tmp_permute_nb,
         GGML_MAX_DIMS, ACL_FORMAT_ND);
     int64_t permute_dim[] = {0, 1, 3, 2};
     int64_t num_dims = 4;
-    aclnn_permute(ctx, acl_src, tmp_permute_tenosr, permute_dim, num_dims);
+    aclnn_permute(ctx, acl_src, tmp_permute_tensor, permute_dim, num_dims);
 
     // timestep * freq
     int64_t tmp_mul_ne[] = {src->ne[1] * half, src->ne[0], src->ne[2],
@@ -1270,7 +1210,7 @@ void ggml_cann_timestep_embedding(ggml_backend_cann_context& ctx,
         tmp_mul_buffer, ggml_cann_type_mapping(src->type),
         ggml_type_size(src->type), tmp_mul_ne, tmp_mul_nb, GGML_MAX_DIMS,
         ACL_FORMAT_ND);
-    aclnn_mul(ctx, tmp_permute_tenosr, tmp_arange_tensor, tmp_mul_tensor);
+    aclnn_mul(ctx, tmp_permute_tensor, tmp_arange_tensor, tmp_mul_tensor);
 
     // cos
     ggml_cann_pool_alloc cos_allocator(
@@ -1298,17 +1238,13 @@ void ggml_cann_timestep_embedding(ggml_backend_cann_context& ctx,
     int64_t concat_dim = 3;
     aclTensor* acl_dst = ggml_cann_create_tensor(dst);
     aclTensor* tensors[] = {tmp_cos_tensor, tmp_sin_tensor};
-    aclTensorList* tensorList = aclCreateTensorList(tensors, 2);
-    aclnn_concat(ctx, tensorList, acl_dst, concat_dim);
+    aclTensorList* tensor_list = aclCreateTensorList(tensors, 2);
+    aclnn_concat(ctx, tensor_list, acl_dst, concat_dim);
 
     // release
     // segmentation fault when delete both tensorList and his elements.
-    ACL_CHECK(aclDestroyTensorList(tensorList));
-    ACL_CHECK(aclDestroyTensor(acl_src));
-    ACL_CHECK(aclDestroyTensor(tmp_arange_tensor));
-    ACL_CHECK(aclDestroyTensor(tmp_permute_tenosr));
-    ACL_CHECK(aclDestroyTensor(tmp_mul_tensor));
-    ACL_CHECK(aclDestroyTensor(acl_dst));
+    ggml_cann_release_resources(ctx, tensor_list, acl_src, tmp_arange_tensor,
+        tmp_permute_tensor, tmp_mul_tensor, acl_dst);
 }
 
 /**
@@ -1324,8 +1260,8 @@ void ggml_cann_timestep_embedding(ggml_backend_cann_context& ctx,
 static void aclnn_fill_scalar(ggml_backend_cann_context& ctx, float scalar,
                               aclTensor* acl_dst) {
     auto acl_scalar = aclCreateScalar(&scalar, aclDataType::ACL_FLOAT);
-    GGML_CANN_CALL_ACLNN_OP(InplaceFillScalar, acl_dst, acl_scalar);
-    ACL_CHECK(aclDestroyScalar(acl_scalar));
+    GGML_CANN_CALL_ACLNN_OP(ctx, InplaceFillScalar, acl_dst, acl_scalar);
+    ggml_cann_release_resources(ctx, acl_scalar);
 }
 
 /**
@@ -1346,7 +1282,7 @@ static void aclnn_fill_scalar(ggml_backend_cann_context& ctx, float scalar,
  */
 static void aclnn_pow_tensor_tensor(ggml_backend_cann_context& ctx,
                                     aclTensor* acl_dst, aclTensor* acl_exp) {
-    GGML_CANN_CALL_ACLNN_OP(InplacePowTensorTensor, acl_dst, acl_exp);
+    GGML_CANN_CALL_ACLNN_OP(ctx, InplacePowTensorTensor, acl_dst, acl_exp);
 }
 
 /**
@@ -1498,15 +1434,9 @@ static void aclnn_alibi(ggml_backend_cann_context& ctx, aclTensor* acl_src,
 
     // add
     aclnn_add(ctx, tmp_output_tensor, acl_src, acl_dst);
-
-    ACL_CHECK(aclDestroyTensor(tmp_arange1_tensor));
-    ACL_CHECK(aclDestroyTensor(tmp_arange2_tensor));
-    ACL_CHECK(aclDestroyTensor(tmp_mk_base1_tensor));
-    ACL_CHECK(aclDestroyTensor(tmp_mk_base2_tensor));
-    ACL_CHECK(aclDestroyTensor(tmp_mk_base_tensor));
-    ACL_CHECK(aclDestroyTensor(tmp_arange_tensor));
-    ACL_CHECK(aclDestroyTensor(tmp_mk_tensor));
-    ACL_CHECK(aclDestroyTensor(tmp_output_tensor));
+    ggml_cann_release_resources(ctx, tmp_arange1_tensor, tmp_arange2_tensor,
+        tmp_mk_base1_tensor, tmp_mk_base2_tensor, tmp_mk_base_tensor,
+        tmp_arange_tensor, tmp_mk_tensor, tmp_output_tensor);
 }
 
 void ggml_cann_cpy(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
@@ -1529,7 +1459,7 @@ void ggml_cann_cpy(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
  */
 static void aclnn_softmax(ggml_backend_cann_context& ctx, aclTensor* acl_src,
                           int64_t dim, aclTensor* acl_dst) {
-    GGML_CANN_CALL_ACLNN_OP(Softmax, acl_src, dim, acl_dst);
+    GGML_CANN_CALL_ACLNN_OP(ctx, Softmax, acl_src, dim, acl_dst);
 }
 
 void ggml_cann_softmax(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
@@ -1579,8 +1509,7 @@ void ggml_cann_softmax(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
                 src1_fp32_nb, GGML_MAX_DIMS);
             aclTensor* acl_src1 = ggml_cann_create_tensor(src1);
             aclnn_cast(ctx, acl_src1, acl_src1_fp32_tensor, ACL_FLOAT);
-
-            ACL_CHECK(aclDestroyTensor(acl_src1));
+            ggml_cann_release_resources(ctx, acl_src1);
         } else {
             acl_src1_fp32_tensor = ggml_cann_create_tensor(src1);
         }
@@ -1633,17 +1562,13 @@ void ggml_cann_softmax(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
 
         // softmax
         aclnn_softmax(ctx, alibi_output_tensor, 3, acl_dst);
-        ACL_CHECK(aclDestroyTensor(alibi_output_tensor));
+        ggml_cann_release_resources(ctx, alibi_output_tensor);
     } else {
         aclnn_softmax(ctx, acl_input_mul_scale_tensor, 3, acl_dst);
     }
 
-    ACL_CHECK(aclDestroyTensor(acl_src0));
-    ACL_CHECK(aclDestroyTensor(acl_src1_fp32_tensor));
-    ACL_CHECK(aclDestroyTensor(acl_dst));
-    ACL_CHECK(aclDestroyScalar(acl_scale));
-    ACL_CHECK(aclDestroyTensor(acl_input_mul_scale_tensor));
-    ACL_CHECK(aclDestroyTensor(tmp_mask_tensor));
+    ggml_cann_release_resources(ctx, acl_src0, acl_src1_fp32_tensor, acl_dst,
+        acl_scale, acl_input_mul_scale_tensor, tmp_mask_tensor);
 }
 
 /**
@@ -1690,10 +1615,8 @@ static void aclnn_embedding_4d(ggml_backend_cann_context& ctx, void* src_buffer,
                 (char*)dst->data + i * dst->nb[3] + j * dst->nb[2],
                 ggml_cann_type_mapping(dst->type), ggml_element_size(dst),
                 acl_out_ne, acl_out_nb, 2);
-            GGML_CANN_CALL_ACLNN_OP(Embedding, acl_src_tensor, acl_index, acl_out);
-            ACL_CHECK(aclDestroyTensor(acl_src_tensor));
-            ACL_CHECK(aclDestroyTensor(acl_index));
-            ACL_CHECK(aclDestroyTensor(acl_out));
+            GGML_CANN_CALL_ACLNN_OP(ctx, Embedding, acl_src_tensor, acl_index, acl_out);
+            ggml_cann_release_resources(ctx, acl_src_tensor, acl_index, acl_out);
         }
     }
 }
@@ -1724,8 +1647,7 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
             aclnn_cast(ctx, acl_src0, src_trans_tensor, ggml_cann_type_mapping(dst->type));
             aclnn_embedding_4d(ctx, src_trans_buffer, src0->ne,
                                    src_trans_nb, src1, dst);
-            ACL_CHECK(aclDestroyTensor(acl_src0));
-            ACL_CHECK(aclDestroyTensor(src_trans_tensor));
+            ggml_cann_release_resources(ctx, acl_src0, src_trans_tensor);
             break;
         }
         case GGML_TYPE_Q8_0: {
@@ -1787,7 +1709,7 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
             aclnn_embedding_4d(ctx, dequant_buffer_allocator.get(),
                                    dequant_ne, dequant_nb, src1, dst);
 
-            ACL_CHECK(aclDestroyTensor(dequant_tensor));
+            ggml_cann_release_resources(ctx, dequant_tensor);
             break;
         }
         default:
@@ -1815,7 +1737,7 @@ static void aclnn_repeat_interleave(ggml_backend_cann_context& ctx,
                                     aclTensor* acl_src, aclTensor* acl_dst,
                                     int64_t dim, int64_t repeats,
                                     int64_t output_size) {
-    GGML_CANN_CALL_ACLNN_OP(RepeatInterleaveIntWithDim, acl_src, repeats, dim,
+    GGML_CANN_CALL_ACLNN_OP(ctx, RepeatInterleaveIntWithDim, acl_src, repeats, dim,
                   output_size, acl_dst);
 }
 
@@ -1864,21 +1786,19 @@ static void ggml_cann_mat_mul_fp(ggml_backend_cann_context& ctx,
 
     switch (n_dims) {
         case 2:
-            GGML_CANN_CALL_ACLNN_OP(Mm, acl_input_tensor, acl_weight_tensor, acl_dst, 2);
+            GGML_CANN_CALL_ACLNN_OP(ctx, Mm, acl_input_tensor, acl_weight_tensor, acl_dst, 2);
             break;
         case 3:
-            GGML_CANN_CALL_ACLNN_OP(BatchMatMul, acl_input_tensor, acl_weight_tensor, acl_dst, 2);
+            GGML_CANN_CALL_ACLNN_OP(ctx, BatchMatMul, acl_input_tensor, acl_weight_tensor, acl_dst, 2);
             break;
         default:
             // ALLOW_FP32_DOWN_PRECISION, when input is
             // fp32, atlas a2 will transpose it to HFLOAT32.
-            GGML_CANN_CALL_ACLNN_OP(Matmul, acl_input_tensor, acl_weight_tensor, acl_dst, 1);
+            GGML_CANN_CALL_ACLNN_OP(ctx, Matmul, acl_input_tensor, acl_weight_tensor, acl_dst, 1);
             break;
     }
 
-    ACL_CHECK(aclDestroyTensor(acl_weight_tensor));
-    ACL_CHECK(aclDestroyTensor(acl_input_tensor));
-    ACL_CHECK(aclDestroyTensor(acl_dst));
+    ggml_cann_release_resources(ctx, acl_weight_tensor, acl_input_tensor, acl_dst);
 }
 
 /**
@@ -1948,9 +1868,7 @@ static void ggml_cann_mul_mat_quant(ggml_backend_cann_context& ctx,
             input_buffer, ACL_FLOAT16, input_elem_size, input_cast_ne,
             input_cast_nb, GGML_MAX_DIMS);
         aclnn_cast(ctx, acl_src1_tensor, acl_input_tensor, ACL_FLOAT16);
-
-        ACL_CHECK(aclDestroyTensor(acl_input_tensor));
-        ACL_CHECK(aclDestroyTensor(acl_src1_tensor));
+        ggml_cann_release_resources(ctx, acl_input_tensor, acl_src1_tensor);
     }
 
     // output
@@ -2003,13 +1921,11 @@ static void ggml_cann_mul_mat_quant(ggml_backend_cann_context& ctx,
             if (src0->ne[0] > QK8_0) {
                 antiquantGroupSize = QK8_0;
             }
-            GGML_CANN_CALL_ACLNN_OP(WeightQuantBatchMatmulV2, acl_input_tensor,
+            GGML_CANN_CALL_ACLNN_OP(ctx, WeightQuantBatchMatmulV2, acl_input_tensor,
                            acl_weight_tensor, acl_scale_tensor, nullptr,
                            nullptr, nullptr, nullptr, antiquantGroupSize,
                            acl_output_tensor);
-            ACL_CHECK(aclDestroyTensor(acl_weight_tensor));
-            ACL_CHECK(aclDestroyTensor(acl_scale_tensor));
-            ACL_CHECK(aclDestroyTensor(acl_output_tensor));
+            ggml_cann_release_resources(ctx, acl_weight_tensor, acl_scale_tensor, acl_output_tensor);
 
             // other splits
             for (int64_t split = 1; split < split_size; split++) {
@@ -2036,16 +1952,14 @@ static void ggml_cann_mul_mat_quant(ggml_backend_cann_context& ctx,
                     (char*)output_buffer + batch1 * output_stride, ACL_FLOAT16,
                     output_elem_size, output_ne, output_nb, 2, ACL_FORMAT_ND,
                     output_ne_offset);
-                GGML_CANN_CALL_ACLNN_OP(WeightQuantBatchMatmulV2, acl_input_tensor,
+                GGML_CANN_CALL_ACLNN_OP(ctx, WeightQuantBatchMatmulV2, acl_input_tensor,
                                    acl_weight_tensor, acl_scale_tensor, nullptr,
                                    nullptr, nullptr, nullptr, antiquantGroupSize,
                                    acl_output_tensor);
-                ACL_CHECK(aclDestroyTensor(acl_weight_tensor));
-                ACL_CHECK(aclDestroyTensor(acl_scale_tensor));
-                ACL_CHECK(aclDestroyTensor(acl_output_tensor));
+                ggml_cann_release_resources(ctx, acl_weight_tensor, acl_scale_tensor, acl_output_tensor);
             }
 
-            ACL_CHECK(aclDestroyTensor(acl_input_tensor));
+            ggml_cann_release_resources(ctx, acl_input_tensor);
         }
     }
 
@@ -2064,8 +1978,7 @@ static void ggml_cann_mul_mat_quant(ggml_backend_cann_context& ctx,
         aclTensor* acl_dst_tensor = ggml_cann_create_tensor(dst);
         aclnn_cast(ctx, acl_output_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type));
 
-        ACL_CHECK(aclDestroyTensor(acl_output_tensor));
-        ACL_CHECK(aclDestroyTensor(acl_dst_tensor));
+        ggml_cann_release_resources(ctx, acl_output_tensor, acl_dst_tensor);
     }
 }
 
@@ -2106,9 +2019,8 @@ static void aclnn_roll(ggml_backend_cann_context& ctx, aclTensor* acl_src,
                        aclTensor* acl_dst, int64_t* shifts, int64_t* dims) {
     aclIntArray* acl_shifts = aclCreateIntArray(shifts, 1);
     aclIntArray* acl_dims = aclCreateIntArray(dims, 1);
-    GGML_CANN_CALL_ACLNN_OP(Roll, acl_src, acl_shifts, acl_dims, acl_dst);
-    ACL_CHECK(aclDestroyIntArray(acl_shifts));
-    ACL_CHECK(aclDestroyIntArray(acl_dims));
+    GGML_CANN_CALL_ACLNN_OP(ctx, Roll, acl_src, acl_shifts, acl_dims, acl_dst);
+    ggml_cann_release_resources(ctx, acl_shifts, acl_dims);
 }
 
 /**
@@ -2130,9 +2042,8 @@ static void aclnn_index_fill_tensor(ggml_backend_cann_context& ctx,
                                     float value) {
     aclIntArray* acl_index = aclCreateIntArray(index, index_num);
     aclScalar* acl_value = aclCreateScalar(&value, aclDataType::ACL_FLOAT);
-    GGML_CANN_CALL_ACLNN_OP(InplaceIndexFillTensor, acl_src, dim, acl_index, acl_value);
-    ACL_CHECK(aclDestroyIntArray(acl_index));
-    ACL_CHECK(aclDestroyScalar(acl_value));
+    GGML_CANN_CALL_ACLNN_OP(ctx, InplaceIndexFillTensor, acl_src, dim, acl_index, acl_value);
+    ggml_cann_release_resources(ctx, acl_index, acl_value);
 }
 
 static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
@@ -2169,7 +2080,8 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
 
     // power
     aclScalar* acl_theta_scale = aclCreateScalar(&theta_scale, aclDataType::ACL_FLOAT);
-    GGML_CANN_CALL_ACLNN_OP(PowScalarTensor, acl_theta_scale, acl_theta_scale_tensor, acl_theta_scale_tensor);
+    GGML_CANN_CALL_ACLNN_OP(ctx, PowScalarTensor, acl_theta_scale, acl_theta_scale_tensor,
+                            acl_theta_scale_tensor);
 
     // freq_scale
     if (freq_scale != 1) {
@@ -2182,7 +2094,7 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
             src2->data, ggml_cann_type_mapping(src2->type),
             ggml_type_size(src2->type), theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
         aclnn_div(ctx, acl_theta_scale_tensor, acl_freq_factors_tensor);
-        ACL_CHECK(aclDestroyTensor(acl_freq_factors_tensor));
+        ggml_cann_release_resources(ctx, acl_freq_factors_tensor);
     }
 
     // position
@@ -2251,12 +2163,8 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
     }
 
     // release
-    ACL_CHECK(aclDestroyTensor(acl_theta_scale_tensor));
-    ACL_CHECK(aclDestroyTensor(acl_position_tensor));
-    ACL_CHECK(aclDestroyTensor(acl_theta_tensor));
-    ACL_CHECK(aclDestroyTensor(acl_sin_tensor));
-    ACL_CHECK(aclDestroyTensor(acl_cos_tensor));
-    ACL_CHECK(aclDestroyScalar(acl_theta_scale));
+    ggml_cann_release_resources(ctx, acl_theta_scale_tensor, acl_position_tensor,
+        acl_theta_tensor, acl_sin_tensor, acl_cos_tensor, acl_theta_scale);
 }
 
 #ifdef __cplusplus
@@ -2368,8 +2276,7 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
         int64_t shifts[] = {1};
         int64_t dims[] = {3};
         aclnn_roll(ctx, acl_input_tensor, acl_input_roll_tensor, shifts, dims);
-        ACL_CHECK(aclDestroyTensor(acl_input_roll_tensor));
-        ACL_CHECK(aclDestroyTensor(acl_input_tensor));
+        ggml_cann_release_resources(ctx, acl_input_roll_tensor, acl_input_tensor);
 
         // init [-1, 1, -1, 1, ...]
         minus_one_scale_buffer = minus_one_scale_allocator.get();
@@ -2405,8 +2312,7 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
         int64_t dims[] = {3};
         aclnn_roll(ctx, acl_input_tensor, acl_input_roll_tensor, shifts, dims);
 
-        ACL_CHECK(aclDestroyTensor(acl_input_roll_tensor));
-        ACL_CHECK(aclDestroyTensor(acl_input_tensor));
+        ggml_cann_release_resources(ctx, acl_input_roll_tensor, acl_input_tensor);
         // init [-1, -1, -1, 1, 1,1,...]
         minus_one_scale_buffer = minus_one_scale_allocator.get();
         int64_t minus_one_ne[4] = {src0->ne[0], 1, 1, 1};
@@ -2431,7 +2337,7 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
         bool inplace = true;
         float scale = -1;
         aclnn_muls(ctx, acl_first_half_tensor, scale, nullptr, inplace);
-        ACL_CHECK(aclDestroyTensor(acl_first_half_tensor));
+        ggml_cann_release_resources(ctx, acl_first_half_tensor);
     }
 
     // TODO: n_dims < ne0
@@ -2496,14 +2402,10 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
                   output_fp32_tensor);
         aclnn_cast(ctx, output_fp32_tensor, acl_dst, ACL_FLOAT16);
 
-        ACL_CHECK(aclDestroyTensor(input_fp32_tensor1));
-        ACL_CHECK(aclDestroyTensor(input_fp32_tensor2));
-        ACL_CHECK(aclDestroyTensor(output_fp32_tensor));
-        ACL_CHECK(aclDestroyTensor(acl_sin_reshape_tensor));
-        ACL_CHECK(aclDestroyTensor(acl_minus_one_tensor));
-        ACL_CHECK(aclDestroyTensor(acl_input_roll_mul_scale_tensor));
-        ACL_CHECK(aclDestroyTensor(acl_input_roll_reshape_tensor));
-        ACL_CHECK(aclDestroyTensor(acl_src));
+        ggml_cann_release_resources(ctx, input_fp32_tensor1, input_fp32_tensor2,
+            output_fp32_tensor, acl_sin_reshape_tensor,
+            acl_minus_one_tensor, acl_input_roll_mul_scale_tensor,
+            acl_input_roll_reshape_tensor, acl_src);
     }
     return;
 #endif
@@ -2513,8 +2415,8 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
 
     switch (src0->type) {
         case GGML_TYPE_F32: {
-            GGML_CANN_CALL_ACLNN_OP(RotaryPositionEmbedding, acl_src, acl_cos_reshape_tensor,
-                acl_sin_reshape_tensor, acl_mode, acl_dst);
+            GGML_CANN_CALL_ACLNN_OP(ctx, RotaryPositionEmbedding, acl_src,
+                acl_cos_reshape_tensor, acl_sin_reshape_tensor, acl_mode, acl_dst);
             break;
         }
         case GGML_TYPE_F16: {
@@ -2540,23 +2442,22 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
 
             aclnn_cast(ctx, acl_src, acl_src_trans_tensor, ACL_FLOAT);
 
-            GGML_CANN_CALL_ACLNN_OP(RotaryPositionEmbedding, acl_src_trans_tensor, acl_cos_reshape_tensor,
-                acl_sin_reshape_tensor, acl_mode, acl_dst_trans_tensor);
+            GGML_CANN_CALL_ACLNN_OP(ctx, RotaryPositionEmbedding, acl_src_trans_tensor,
+                acl_cos_reshape_tensor, acl_sin_reshape_tensor, acl_mode,
+                acl_dst_trans_tensor);
 
             aclnn_cast(ctx, acl_dst_trans_tensor, acl_dst, ACL_FLOAT16);
 
-            ACL_CHECK(aclDestroyTensor(acl_src_trans_tensor));
-            ACL_CHECK(aclDestroyTensor(acl_dst_trans_tensor));
+            ggml_cann_release_resources(ctx, acl_src_trans_tensor,
+                acl_dst_trans_tensor);
             break;
         }
         default:
             GGML_ABORT("Unsupported tensor type for GGML_OP_ROPE");
             break;
     }
-    ACL_CHECK(aclDestroyTensor(acl_src));
-    ACL_CHECK(aclDestroyTensor(acl_cos_reshape_tensor));
-    ACL_CHECK(aclDestroyTensor(acl_sin_reshape_tensor));
-    ACL_CHECK(aclDestroyTensor(acl_dst));
+    ggml_cann_release_resources(ctx, acl_cos_reshape_tensor,
+        acl_sin_reshape_tensor, acl_src, acl_dst);
 }
 
 
@@ -2566,10 +2467,9 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     aclTensor* acl_src = ggml_cann_create_tensor(src0);
     aclTensor* acl_dst = ggml_cann_create_tensor(dst, dst->ne, dst->nb, 3);
 
-    GGML_CANN_CALL_ACLNN_OP(ArgMax, acl_src, 3, false, acl_dst);
+    GGML_CANN_CALL_ACLNN_OP(ctx, ArgMax, acl_src, 3, false, acl_dst);
 
-    ACL_CHECK(aclDestroyTensor(acl_src));
-    ACL_CHECK(aclDestroyTensor(acl_dst));
+    ggml_cann_release_resources(ctx, acl_src, acl_dst);
 }
 
 void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* dst){
@@ -2598,14 +2498,10 @@ void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* ds
     cubeMathType = 1;
 #endif
 
-    GGML_CANN_CALL_ACLNN_OP(Convolution, acl_input, acl_weight, nullptr, stride,
+    GGML_CANN_CALL_ACLNN_OP(ctx, Convolution, acl_input, acl_weight, nullptr, stride,
         padding, dilation, transposed, padding, groups, acl_dst, cubeMathType);
 
-    ACL_CHECK(aclDestroyTensor(acl_weight));
-    ACL_CHECK(aclDestroyTensor(acl_dst));
-    ACL_CHECK(aclDestroyIntArray(stride));
-    ACL_CHECK(aclDestroyIntArray(padding));
-    ACL_CHECK(aclDestroyIntArray(dilation));
+    ggml_cann_release_resources(ctx, acl_weight, acl_dst, stride, padding, dilation);
 }
 
 void ggml_cann_elu(ggml_backend_cann_context& ctx, ggml_tensor* dst){
@@ -2618,12 +2514,10 @@ void ggml_cann_elu(ggml_backend_cann_context& ctx, ggml_tensor* dst){
     aclScalar* alpha = nullptr;
     alpha = aclCreateScalar(&alphaValue, aclDataType::ACL_FLOAT);
 
-    GGML_CANN_CALL_ACLNN_OP(Elu, acl_input, alpha, alpha, alpha,
+    GGML_CANN_CALL_ACLNN_OP(ctx, Elu, acl_input, alpha, alpha, alpha,
         acl_dst);
 
-    ACL_CHECK(aclDestroyTensor(acl_input));
-    ACL_CHECK(aclDestroyTensor(acl_dst));
-    ACL_CHECK(aclDestroyScalar(alpha));
+    ggml_cann_release_resources(ctx, acl_input, acl_dst, alpha);
 }
 
 void ggml_cann_mean(ggml_backend_cann_context& ctx, ggml_tensor* dst){
@@ -2636,11 +2530,9 @@ void ggml_cann_mean(ggml_backend_cann_context& ctx, ggml_tensor* dst){
     aclIntArray* reduceDim = aclCreateIntArray(reduceDimValue, 1);
     bool keepDim = true;
 
-    GGML_CANN_CALL_ACLNN_OP(Mean, acl_src, reduceDim, keepDim, ACL_FLOAT, acl_dst);
+    GGML_CANN_CALL_ACLNN_OP(ctx, Mean, acl_src, reduceDim, keepDim, ACL_FLOAT, acl_dst);
 
-    ACL_CHECK(aclDestroyTensor(acl_src));
-    ACL_CHECK(aclDestroyTensor(acl_dst));
-    ACL_CHECK(aclDestroyIntArray(reduceDim));
+    ggml_cann_release_resources(ctx, acl_src, acl_dst, reduceDim);
 }
 
 void ggml_cann_pad_reflect_1d(ggml_backend_cann_context& ctx, ggml_tensor* dst){
@@ -2660,12 +2552,11 @@ void ggml_cann_pad_reflect_1d(ggml_backend_cann_context& ctx, ggml_tensor* dst){
             ggml_cann_type_mapping(dst->type), ggml_element_size(dst),
             dst->ne, dst->nb, 3);
 
-            GGML_CANN_CALL_ACLNN_OP(ReflectionPad1d, acl_src, paddings, acl_dst);
+            GGML_CANN_CALL_ACLNN_OP(ctx, ReflectionPad1d, acl_src, paddings, acl_dst);
 
-            ACL_CHECK(aclDestroyTensor(acl_src));
-            ACL_CHECK(aclDestroyTensor(acl_dst));
+            ggml_cann_release_resources(ctx, acl_src, acl_dst);
     }
-    ACL_CHECK(aclDestroyIntArray(paddings));
+    ggml_cann_release_resources(ctx, paddings);
 }
 
 void ggml_cann_count_equal(ggml_backend_cann_context& ctx, ggml_tensor* dst){
@@ -2675,12 +2566,11 @@ void ggml_cann_count_equal(ggml_backend_cann_context& ctx, ggml_tensor* dst){
     aclTensor* acl_self = ggml_cann_create_tensor(src0);
     aclTensor* acl_other = ggml_cann_create_tensor(src1);
 
-    GGML_CANN_CALL_ACLNN_OP(InplaceEqTensor, acl_self, acl_other);
+    GGML_CANN_CALL_ACLNN_OP(ctx, InplaceEqTensor, acl_self, acl_other);
 
     ggml_cann_sum(ctx, dst);
 
-    ACL_CHECK(aclDestroyTensor(acl_self));
-    ACL_CHECK(aclDestroyTensor(acl_other));
+    ggml_cann_release_resources(ctx, acl_self, acl_other);
 }
 
 void ggml_cann_step(ggml_backend_cann_context& ctx, ggml_tensor* dst){
@@ -2693,9 +2583,7 @@ void ggml_cann_step(ggml_backend_cann_context& ctx, ggml_tensor* dst){
     aclScalar* alpha = nullptr;
     alpha = aclCreateScalar(&alphaValue, aclDataType::ACL_FLOAT);
 
-    GGML_CANN_CALL_ACLNN_OP(GtScalar, acl_src, alpha, acl_dst);
+    GGML_CANN_CALL_ACLNN_OP(ctx, GtScalar, acl_src, alpha, acl_dst);
 
-    ACL_CHECK(aclDestroyTensor(acl_src));
-    ACL_CHECK(aclDestroyTensor(acl_dst));
-    ACL_CHECK(aclDestroyScalar(alpha));
+    ggml_cann_release_resources(ctx, acl_src, acl_dst, alpha);
 }
diff --git a/ggml/src/ggml-cann/aclnn_ops.h b/ggml/src/ggml-cann/aclnn_ops.h
index b2d1b3c36..462351542 100644
--- a/ggml/src/ggml-cann/aclnn_ops.h
+++ b/ggml/src/ggml-cann/aclnn_ops.h
@@ -23,6 +23,7 @@
 #ifndef CANN_ACLNN_OPS
 #define CANN_ACLNN_OPS
 
+#include 
 #include 
 #include 
 #include 
@@ -713,6 +714,270 @@ void ggml_cann_count_equal(ggml_backend_cann_context& ctx, ggml_tensor* dst);
  */
 void ggml_cann_step(ggml_backend_cann_context& ctx, ggml_tensor* dst);
 
+/**
+ * @brief A generic wrapper for ACL resources with custom deleter support.
+ */
+using any_acl_resource = std::unique_ptr<void, std::function<void(void*)>>;
+
+/**
+ * @brief Trait structure used to define how to destroy a given ACL resource type.
+ *
+ * @tparam T ACL resource type.
+ */
+template<typename T>
+struct acl_resource_traits;
+
+/**
+ * @brief Specialization for aclTensor, defines how to destroy an aclTensor resource.
+ */
+template<>
+struct acl_resource_traits<aclTensor> {
+    static void destroy(void* p) {
+        ACL_CHECK(aclDestroyTensor(static_cast<aclTensor*>(p)));
+    }
+};
+
+/**
+ * @brief Specialization for aclIntArray, defines how to destroy an aclIntArray resource.
+ */
+template<>
+struct acl_resource_traits<aclIntArray> {
+    static void destroy(void* p) {
+        ACL_CHECK(aclDestroyIntArray(static_cast<aclIntArray*>(p)));
+    }
+};
+
+/**
+ * @brief Specialization for aclScalar, defines how to destroy an aclScalar resource.
+ */
+template<>
+struct acl_resource_traits<aclScalar> {
+    static void destroy(void* p) {
+        ACL_CHECK(aclDestroyScalar(static_cast<aclScalar*>(p)));
+    }
+};
+
+/**
+ * @brief Specialization for aclTensorList, defines how to destroy an aclTensorList resource.
+ */
+template<>
+struct acl_resource_traits<aclTensorList> {
+    static void destroy(void* p) {
+        ACL_CHECK(aclDestroyTensorList(static_cast<aclTensorList*>(p)));
+    }
+};
+
+/**
+ * @brief Creates a generic ACL resource wrapper with proper destruction logic.
+ *
+ * @tparam T ACL resource type.
+ * @param ptr Raw pointer to ACL resource.
+ * @return any_acl_resource Smart pointer that handles destruction.
+ */
+template<typename T>
+any_acl_resource make_acl_resource(T* ptr) {
+    return any_acl_resource(
+        static_cast<void*>(ptr),
+        [](void* p) {
+            acl_resource_traits<T>::destroy(p);
+        }
+    );
+}
+
+/**
+ * @brief Registers multiple ACL resources into a vector for lifetime management.
+ *
+ * @tparam Args Variadic list of ACL resource types.
+ * @param vec Target vector to hold ACL resources.
+ * @param args Raw pointers to ACL resources.
+ */
+template<typename... Args>
+void register_acl_resources(std::vector<any_acl_resource>& vec, Args*... args) {
+    (vec.emplace_back(make_acl_resource(args)), ...);
+}
+
+/**
+ * @brief Task class that wraps the execution of an aclnn function call.
+ */
+class aclnn_task : public cann_task {
+    public:
+        aclnn_task(aclnn_func_t aclnn_func, void * workspace_addr,
+                   uint64_t workspace_size, aclOpExecutor * executor,
+                   aclrtStream stream) :
+            aclnn_func_(aclnn_func),
+            workspace_addr_(workspace_addr),
+            workspace_size_(workspace_size),
+            executor_(executor),
+            stream_(stream) {}
+        virtual void run_task() override {
+            ACL_CHECK(aclnn_func_(workspace_addr_, workspace_size_, executor_, stream_));
+        }
+    private:
+        aclnn_func_t aclnn_func_;
+        void *          workspace_addr_;
+        uint64_t        workspace_size_;
+        aclOpExecutor * executor_;
+        aclrtStream     stream_;
+};
+
+/**
+ * @brief Task class that releases ACL resources after usage.
+ */
+class release_resource_task : public cann_task {
+public:
+    release_resource_task(std::vector<any_acl_resource>&& resources){
+        resource_ = std::move(resources);
+    }
+
+    virtual void run_task() override {
+        resource_.clear();
+    }
+private:
+    std::vector<any_acl_resource> resource_;
+};
+
+/**
+ * @brief Task class for performing asynchronous memory copy operations.
+ */
+class async_memcpy_task : public cann_task {
+public:
+    async_memcpy_task(void* dst, const void* src, size_t size,
+                      aclrtMemcpyKind kind, aclrtStream stream)
+        : dst_(dst), src_(src), size_(size), kind_(kind), stream_(stream) {}
+
+    virtual void run_task() override {
+        ACL_CHECK(aclrtMemcpyAsync(dst_, size_, src_, size_, kind_, stream_));
+    }
+private:
+    void* dst_;
+    const void* src_;
+    size_t size_;
+    aclrtMemcpyKind kind_;
+    aclrtStream stream_;
+};
+
+/**
+ * @brief Task class for performing asynchronous memory set operations.
+ */
+class async_memset_task : public cann_task {
+    public:
+    async_memset_task(void* buffer, size_t size, int32_t value, aclrtStream stream)
+            : buffer_(buffer), size_(size), value_(value), stream_(stream) {}
+
+        virtual void run_task() override {
+            ACL_CHECK(aclrtMemsetAsync(buffer_, size_, value_, size_, stream_));
+        }
+    private:
+        void* buffer_;
+        size_t size_;
+        int32_t value_;
+        aclrtStream stream_;
+};
+
+/**
+ * @brief Launches an asynchronous task using the memory allocator.
+ *
+ * This macro submits an asynchronous task on the specified stream.
+ * The task uses memory allocated by the allocator. It is guaranteed
+ * that the memory will not be accessed by other tasks until this task
+ * completes, due to the sequential execution order within the same stream.
+ *
+ * @param CTX Backend context providing the memory pool, stream, and async mode.
+ * @param OP_NAME aclnn operator name.
+ * @param args Additional arguments required by the task.
+ *
+ * @note
+ * Memory from the allocator will be "freed" immediately and can be
+ * reallocated to other pointers. However, it won't be accessed by any
+ * other task before this asynchronous task ends, because all tasks in the
+ * same stream are executed in queue order.
+ */
+
+#define GGML_CANN_CALL_ACLNN_OP(CTX, OP_NAME, ...)                                          \
+    do {                                                                                    \
+        uint64_t        workspaceSize = 0;                                                  \
+        aclOpExecutor * executor;                                                           \
+        void *          workspaceAddr = nullptr;                                            \
+        ACL_CHECK(aclnn##OP_NAME##GetWorkspaceSize(__VA_ARGS__, &workspaceSize, &executor));\
+        /* workspace should be allocated in the main thread to keep malloc order when using vmm. */ \
+        if (workspaceSize > 0) {                                                            \
+            ggml_cann_pool_alloc workspace_allocator(CTX.pool(), workspaceSize);            \
+            workspaceAddr = workspace_allocator.get();                                      \
+        }                                                                                   \
+        if (CTX.async_mode) {                                                               \
+            auto task =                                                                     \
+                std::make_unique<aclnn_task>(aclnn##OP_NAME, workspaceAddr, workspaceSize,  \
+                    executor, CTX.stream()); \
+            CTX.task_queue.submit_task(std::move(task));                                    \
+        } else {                                                                            \
+            ACL_CHECK(aclnn##OP_NAME(workspaceAddr, workspaceSize, executor, CTX.stream()));\
+        }                                                                                   \
+    } while (0)
+
+/**
+ * @brief Registers and releases multiple ACL resources, optionally deferring the release
+ *        using a task.
+ *
+ * @tparam Args Types of the ACL resources.
+ * @param ctx Backend context which manages task submission and async mode.
+ * @param args Pointers to ACL resources to be released.
+ */
+template <typename... Args>
+void ggml_cann_release_resources(ggml_backend_cann_context & ctx, Args &&... args) {
+    std::vector<any_acl_resource> resources;
+    register_acl_resources(resources, std::forward<Args>(args)...);
+    if(ctx.async_mode) {
+        auto task = std::make_unique<release_resource_task>(std::move(resources));
+        ctx.task_queue.submit_task(std::move(task));
+    }
+}
+
+/**
+ * @brief Performs an asynchronous memory copy operation, optionally deferred via task submission.
+ *
+ * @param ctx Backend context containing stream and async configuration.
+ * @param dst Destination memory address.
+ * @param src Source memory address.
+ * @param len Size of memory to copy (in bytes).
+ * @param kind Type of memory copy (host-to-device, device-to-host, etc).
+ */
+inline void ggml_cann_async_memcpy(ggml_backend_cann_context & ctx, void * dst,
+                                   const void * src, size_t len, aclrtMemcpyKind kind) {
+    if (ctx.async_mode) {
+        auto task = std::make_unique<async_memcpy_task>(dst, const_cast<void *>(src), len, kind, ctx.stream());
+        ctx.task_queue.submit_task(std::move(task));
+    } else {
+        ACL_CHECK(aclrtMemcpyAsync(dst, len, src, len, kind, ctx.stream()));
+    }
+}
+
+inline void ggml_cann_async_memcpy(ggml_backend_cann_context * ctx, void * dst,
+                                   const void * src, size_t len, aclrtMemcpyKind kind) {
+    if (ctx->async_mode) {
+        auto task = std::make_unique<async_memcpy_task>(dst, const_cast<void *>(src), len, kind, ctx->stream());
+        ctx->task_queue.submit_task(std::move(task));
+    } else {
+        ACL_CHECK(aclrtMemcpyAsync(dst, len, src, len, kind, ctx->stream()));
+    }
+}
+
+/**
+ * @brief Performs an asynchronous memory set operation, optionally deferred via task submission.
+ *
+ * @param ctx Backend context containing stream and async configuration.
+ * @param buffer Memory buffer to be set.
+ * @param size Size of the memory buffer (in bytes).
+ * @param value Value to set in the buffer.
+ */
+inline void ggml_cann_async_memset(ggml_backend_cann_context & ctx, void * buffer,
+                                   size_t size, int value) {
+    if (ctx.async_mode) {
+        auto task = std::make_unique<async_memset_task>(buffer, size, value, ctx.stream());
+        ctx.task_queue.submit_task(std::move(task));
+    } else {
+        ACL_CHECK(aclrtMemsetAsync(buffer, size, value, size, ctx.stream()));
+    }
+}
+
 /**
  * @brief Applies a element-wise operation to two input tensors using the CANN
  * backend.
@@ -742,42 +1007,9 @@ void ggml_cann_binary_op(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     bcast_shape(src0, src1, dst, &acl_src0, &acl_src1, &acl_dst);
     binary_op(ctx, acl_src0, acl_src1, acl_dst);
 
-    ACL_CHECK(aclDestroyTensor(acl_src0));
-    ACL_CHECK(aclDestroyTensor(acl_src1));
-    ACL_CHECK(aclDestroyTensor(acl_dst));
+    ggml_cann_release_resources(ctx, acl_src0, acl_src1, acl_dst);
 }
 
-/**
- * @brief Launches an asynchronous task using the memory allocator.
- *
- * This macro submit an asynchronous task on the specified stream.
- * The task uses memory allocated by the allocator. It is guaranteed
- * that the memory will not be accessed by other tasks until this task
- * completes, due to the sequential execution order within the same stream.
- *
- * @param OP_NAME aclnn operator name.
- * @param args Additional arguments required by the task.
- *
- * @note
- * Memory from the allocator will be "freed" immediately and can be
- * reallocated to other pointers. However, it won't be accessed by any
- * other task before this asynchronous task ends, because all tasks in the
- * same stream are executed in queue order.
- */
-#define GGML_CANN_CALL_ACLNN_OP(OP_NAME, ...)                                                \
-    do {                                                                                     \
-        uint64_t        workspaceSize = 0;                                                   \
-        aclOpExecutor * executor;                                                            \
-        void *          workspaceAddr = nullptr;                                             \
-                                                                                             \
-        ACL_CHECK(aclnn##OP_NAME##GetWorkspaceSize(__VA_ARGS__, &workspaceSize, &executor)); \
-                                                                                             \
-        if (workspaceSize > 0) {                                                             \
-            ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);             \
-            workspaceAddr = workspace_allocator.get();                                       \
-        }                                                                                    \
-        ACL_CHECK(aclnn##OP_NAME(workspaceAddr, workspaceSize, executor, ctx.stream()));     \
-    } while (0)
 
 /**
  * @brief Applies a unary operation to an input tensor using the CANN backend.
@@ -799,9 +1031,7 @@ template 
     aclTensor* acl_dst = ggml_cann_create_tensor(dst);
 
     unary_op(ctx, acl_src, acl_dst);
-
-    ACL_CHECK(aclDestroyTensor(acl_src));
-    ACL_CHECK(aclDestroyTensor(acl_dst));
+    ggml_cann_release_resources(ctx, acl_src, acl_dst);
 }
 
 /**
@@ -832,7 +1062,7 @@ void ggml_cann_unary_op(
  *
  * Internally, the lambda will call:
  * @code
- * GGML_CANN_CALL_ACLNN_OP(OP_NAME, acl_src, acl_dst);
+ * GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, acl_src, acl_dst);
  * @endcode
  *
  * @param OP_NAME The name of the ACL unary operator to invoke via GGML_CANN_CALL_ACLNN_OP.
@@ -840,14 +1070,14 @@ void ggml_cann_unary_op(
  * @see ggml_cann_unary_op
  * @see GGML_CANN_CALL_ACLNN_OP
  */
-#define GGML_CANN_CALL_UNARY_OP(OP_NAME)                         \
-    do {                                                         \
-        auto lambda = [](ggml_backend_cann_context& ctx,         \
-            aclTensor* acl_src,                                  \
-            aclTensor* acl_dst) {                                \
-            GGML_CANN_CALL_ACLNN_OP(OP_NAME, acl_src, acl_dst);  \
-        };                                                       \
-        ggml_cann_unary_op(lambda, ctx, dst);                    \
-    }                                                            \
+#define GGML_CANN_CALL_UNARY_OP(OP_NAME)                              \
+    do {                                                              \
+        auto lambda = [](ggml_backend_cann_context& ctx,              \
+            aclTensor* acl_src,                                       \
+            aclTensor* acl_dst) {                                     \
+            GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, acl_src, acl_dst);  \
+        };                                                            \
+        ggml_cann_unary_op(lambda, ctx, dst);                         \
+    }                                                                 \
     while (0)
 #endif  // CANN_ACLNN_OPS
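
The header above wires three pieces together: per-type traits that know the matching aclDestroy* call, a type-erased smart pointer that carries its own deleter, and a fold expression that collects an arbitrary mix of handles for deferred release. The following is a minimal standalone sketch of that pattern; `FakeTensor`, `FakeScalar`, and the `destroy_*` functions are invented stand-ins for the ACL handle types and are not part of the real API.

```
#include <cstdio>
#include <functional>
#include <memory>
#include <vector>

// Stand-ins for ACL handle types and their C-style destroy functions.
struct FakeTensor {};
struct FakeScalar {};
void destroy_tensor(FakeTensor* t) { std::puts("tensor destroyed"); delete t; }
void destroy_scalar(FakeScalar* s) { std::puts("scalar destroyed"); delete s; }

// Type-erased owning handle whose deleter travels with it.
using any_resource = std::unique_ptr<void, std::function<void(void*)>>;

// Trait that knows how to destroy each resource type.
template <typename T> struct resource_traits;
template <> struct resource_traits<FakeTensor> {
    static void destroy(void* p) { destroy_tensor(static_cast<FakeTensor*>(p)); }
};
template <> struct resource_traits<FakeScalar> {
    static void destroy(void* p) { destroy_scalar(static_cast<FakeScalar*>(p)); }
};

// Wrap a raw handle so the correct destroyer is bound at wrap time.
template <typename T>
any_resource make_resource(T* ptr) {
    return any_resource(static_cast<void*>(ptr),
                        [](void* p) { resource_traits<T>::destroy(p); });
}

// Fold expression: wrap every argument, whatever its type, into one vector.
template <typename... Args>
void register_resources(std::vector<any_resource>& vec, Args*... args) {
    (vec.emplace_back(make_resource(args)), ...);
}

int main() {
    std::vector<any_resource> resources;
    register_resources(resources, new FakeTensor(), new FakeScalar());
    // Clearing the vector (or letting it go out of scope) runs each deleter,
    // which is what a deferred release task does on the worker thread.
    resources.clear();
}
```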
diff --git a/ggml/src/ggml-cann/common.h b/ggml/src/ggml-cann/common.h
index 5164cb74e..7ef80a479 100644
--- a/ggml/src/ggml-cann/common.h
+++ b/ggml/src/ggml-cann/common.h
@@ -31,9 +31,16 @@
 #include 
 #include 
 #include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
 
 #include "../include/ggml-cann.h"
 #include "../include/ggml.h"
+#include "../ggml-impl.h"
 
 #define MATRIX_ROW_PADDING 512
 #define GGML_CANN_MAX_STREAMS 8
@@ -205,6 +212,127 @@ struct ggml_cann_pool_alloc {
     ggml_cann_pool_alloc& operator=(ggml_cann_pool_alloc&&) = delete;
 };
 
+/**
+ * @brief Function pointer type for ACLNN operator calls.
+ */
+using aclnn_func_t = aclnnStatus (*)(void*, uint64_t, aclOpExecutor*, aclrtStream);
+
+/**
+ * @brief Base class for all CANN tasks to be submitted to the task queue.
+ *
+ * Users should override the run_task() method with actual task logic.
+ */
+class cann_task {
+public:
+    virtual void run_task() {}
+};
+
+/**
+ * @brief A lock-free ring-buffer based task queue for asynchronously executing cann_task instances.
+ */
+class cann_task_queue {
+public:
+    /**
+     * @brief Constructs a task queue with a fixed power-of-two capacity for a specific device.
+     *
+     * @param capacity Queue capacity. Must be a power of 2.
+     * @param device Target device ID (used for context setting).
+     */
+    explicit cann_task_queue(size_t capacity, int32_t device)
+        : buffer_(capacity), capacity_(capacity), head_(0), tail_(0),
+          running_(false), device_(device) {
+        GGML_ASSERT((capacity & (capacity - 1)) == 0 && "capacity must be power of 2");
+        mask_ = capacity_ - 1;
+    }
+
+    /**
+     * @brief Attempts to enqueue a task into the queue.
+     *
+     * @param item Unique pointer to the task.
+     * @return true if the task was successfully enqueued, false if the queue was full.
+     */
+    bool enqueue(std::unique_ptr<cann_task>&& item) {
+        size_t next_tail = (tail_ + 1) & mask_;
+
+        if (next_tail == head_) {
+            return false;
+        }
+
+        buffer_[tail_] = std::move(item);
+        std::atomic_thread_fence(std::memory_order_release);
+        tail_ = next_tail;
+
+        return true;
+    }
+
+    /**
+     * @brief Submits a task to the queue, and starts the worker thread if not already running.
+     *
+     * @param task Task to be submitted.
+     */
+    void submit_task(std::unique_ptr<cann_task>&& task) {
+        while(!enqueue(std::move(task))) {
+            std::this_thread::yield();
+            continue;
+        }
+
+        if (!running_) {
+            running_ = true;
+            thread_ = std::thread(&cann_task_queue::execute, this);
+        }
+
+    }
+
+    /**
+     * @brief Waits until the queue is completely empty and no tasks are being processed.
+     */
+    void wait() {
+        while (running_ && head_ != tail_) {
+            std::this_thread::yield();
+            continue;
+        }
+    }
+
+    /**
+     * @brief Stops the task queue and joins the worker thread.
+     */
+    void stop() {
+        running_ = false;
+        if (thread_.joinable()) {
+            thread_.join();
+        }
+    }
+
+private:
+    /**
+     * @brief Worker thread function that continuously dequeues and executes tasks.
+     */
+    void execute() {
+        ggml_cann_set_device(device_);
+
+        while (running_) {
+            if(head_ == tail_) {
+                std::this_thread::yield();
+                continue;
+            }
+
+            std::atomic_thread_fence(std::memory_order_acquire);
+            buffer_[head_]->run_task();
+            buffer_[head_].reset();
+            head_ = (head_ + 1) & mask_;
+        }
+    }
+
+    std::vector<std::unique_ptr<cann_task>> buffer_;
+    const size_t capacity_;
+    size_t mask_;
+    size_t head_;
+    size_t tail_;
+    bool running_;
+    std::thread thread_;
+    int32_t device_;
+};
+
 /**
  * @brief Context for managing CANN backend operations.
  */
@@ -213,6 +341,8 @@ struct ggml_backend_cann_context {
     std::string name;                /**< Name of the device. */
     std::string description;         /**< Description of the device. */
     aclrtEvent copy_event = nullptr; /**< Event for managing copy operations. */
+    cann_task_queue task_queue;
+    bool async_mode;
 
     aclrtStream streams[GGML_CANN_MAX_STREAMS] = {nullptr}; /**< Array of streams for the device. */
 
@@ -221,9 +351,12 @@ struct ggml_backend_cann_context {
      * @param device Device ID.
      */
     explicit ggml_backend_cann_context(int device)
-        : device(device), name("CANN" + std::to_string(device)) {
+        : device(device), name("CANN" + std::to_string(device)), task_queue(1024, device) {
         ggml_cann_set_device(device);
         description = aclrtGetSocName();
+        async_mode = (getenv("GGML_CANN_ASYNC_MODE") != nullptr);
+        GGML_LOG_INFO("%s: device %d async operator submission is %s\n", __func__,
+            device, async_mode ? "ON" : "OFF");
     }
 
     /**
@@ -231,6 +364,7 @@ struct ggml_backend_cann_context {
      */
     ~ggml_backend_cann_context() {
         ggml_cann_set_device(device);
+        task_queue.stop();
         if (copy_event != nullptr) {
             ACL_CHECK(aclrtDestroyEvent(copy_event));
         }
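
The task queue above is a single-producer/single-consumer ring buffer: the submitting thread only advances `tail_`, the worker thread only advances `head_`, and release/acquire fences publish each slot. Below is a simplified, self-contained sketch of that shape, using `std::atomic` indices and `std::function` tasks instead of `cann_task`; it is illustrative only and omits the device pinning that the real queue performs.

```
#include <atomic>
#include <cstdio>
#include <functional>
#include <thread>
#include <vector>

// Minimal SPSC ring buffer: one producer submits tasks, one worker runs them.
class spsc_task_queue {
public:
    explicit spsc_task_queue(size_t capacity)  // capacity must be a power of two
        : buffer_(capacity), mask_(capacity - 1), head_(0), tail_(0) {}

    bool enqueue(std::function<void()> task) {
        size_t tail = tail_.load(std::memory_order_relaxed);
        size_t next = (tail + 1) & mask_;
        if (next == head_.load(std::memory_order_acquire)) {
            return false;  // full
        }
        buffer_[tail] = std::move(task);
        tail_.store(next, std::memory_order_release);  // publish the slot
        return true;
    }

    bool dequeue_and_run() {
        size_t head = head_.load(std::memory_order_relaxed);
        if (head == tail_.load(std::memory_order_acquire)) {
            return false;  // empty
        }
        buffer_[head]();   // execute the task
        buffer_[head] = nullptr;
        head_.store((head + 1) & mask_, std::memory_order_release);
        return true;
    }

private:
    std::vector<std::function<void()>> buffer_;
    const size_t mask_;
    std::atomic<size_t> head_, tail_;
};

int main() {
    spsc_task_queue q(8);
    std::atomic<bool> running{true};

    std::thread worker([&] {
        for (;;) {
            if (q.dequeue_and_run()) continue;
            if (!running.load(std::memory_order_acquire)) {
                while (q.dequeue_and_run()) {}  // drain tasks published before stop
                break;
            }
            std::this_thread::yield();
        }
    });

    for (int i = 0; i < 4; ++i) {
        while (!q.enqueue([i] { std::printf("task %d\n", i); })) {
            std::this_thread::yield();  // back off while the queue is full
        }
    }
    running.store(false, std::memory_order_release);
    worker.join();
}
```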
diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp
index ca41e0260..e2617b06e 100644
--- a/ggml/src/ggml-cann/ggml-cann.cpp
+++ b/ggml/src/ggml-cann/ggml-cann.cpp
@@ -1606,7 +1606,7 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
                     auto lambda = [](ggml_backend_cann_context& ctx,
                         aclTensor* acl_src,
                         aclTensor* acl_dst) {
-                        GGML_CANN_CALL_ACLNN_OP(GeluV2, acl_src, 0, acl_dst);
+                        GGML_CANN_CALL_ACLNN_OP(ctx, GeluV2, acl_src, 0, acl_dst);
                     };
                     ggml_cann_unary_op(lambda, ctx, dst);
                 } break;
@@ -1789,12 +1789,11 @@ static void ggml_backend_cann_free(ggml_backend_t backend) {
     delete backend;
 }
 
+
 /**
  * @brief Sets tensor data asynchronously in the CANN backend.
  *
- * This function asynchronously sets tensor data in the CANN backend. Depending
- * on the tensor type, it may perform data transformations before copying data
- * to the device.
+ * This function asynchronously sets tensor data in the CANN backend.
  *
  * @param backend Pointer to the CANN backend structure.
  * @param tensor Pointer to the tensor structure to set data for.
@@ -1809,23 +1808,28 @@ static void ggml_backend_cann_set_tensor_async(ggml_backend_t backend,
                                                size_t size) {
     ggml_backend_cann_context *cann_ctx =
         (ggml_backend_cann_context *)backend->context;
+    ggml_backend_buffer_t buf =
+        tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
 
-    if (!need_transform(tensor->type)) {
-        ACL_CHECK(aclrtMemcpyAsync((char *)tensor->data + offset, size, data,
-                                   size, ACL_MEMCPY_HOST_TO_DEVICE,
-                                   cann_ctx->stream()));
-    } else {
-        void *transform_buffer = malloc(size);
-        ggml_backend_cann_transform(tensor, data, transform_buffer);
+    GGML_ASSERT(buf->buft == ggml_backend_cann_buffer_type(cann_ctx->device) &&
+        "unsupported buffer type");
+    GGML_ASSERT(!ggml_is_quantized(tensor->type));
 
-        ACL_CHECK(aclrtMemcpyAsync(
-            (char *)tensor->data + offset, size, transform_buffer, size,
-            ACL_MEMCPY_HOST_TO_DEVICE, cann_ctx->stream()));
-        ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));
-        free(transform_buffer);
-    }
+    ggml_cann_async_memcpy(cann_ctx, (char *)tensor->data + offset, data, size,
+        ACL_MEMCPY_HOST_TO_DEVICE);
 }
 
+/**
+ * @brief Gets tensor data asynchronously in the CANN backend.
+ *
+ * This function asynchronously gets tensor data in the CANN backend.
+ *
+ * @param backend Pointer to the CANN backend structure.
+ * @param tensor Pointer to the tensor structure to get data from.
+ * @param data Pointer to the host buffer that receives the tensor data.
+ * @param offset Offset in bytes within the tensor data.
+ * @param size Size of the data to copy in bytes.
+ */
 static void ggml_backend_cann_get_tensor_async(
     ggml_backend_t backend, const ggml_tensor *tensor, void *data,
     size_t offset, size_t size) {
@@ -1836,20 +1840,11 @@ static void ggml_backend_cann_get_tensor_async(
 
     GGML_ASSERT(buf->buft == ggml_backend_cann_buffer_type(cann_ctx->device) &&
                 "unsupported buffer type");
+    GGML_ASSERT(!ggml_is_quantized(tensor->type));
+
+    ggml_cann_async_memcpy(cann_ctx, data, (char *)tensor->data + offset, size,
+        ACL_MEMCPY_DEVICE_TO_HOST);
 
-    if (!need_transform(tensor->type)) {
-        ACL_CHECK(aclrtMemcpyAsync(data, size, (char *)tensor->data + offset,
-                                   size, ACL_MEMCPY_DEVICE_TO_HOST,
-                                   cann_ctx->stream()));
-    } else {
-        void *transform_buffer = malloc(size);
-        ACL_CHECK(aclrtMemcpyAsync(
-            transform_buffer, size, (char *)tensor->data + offset, size,
-            ACL_MEMCPY_DEVICE_TO_HOST, cann_ctx->stream()));
-        ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));
-        ggml_backend_cann_transform_back(tensor, transform_buffer, data);
-        free(transform_buffer);
-    }
 }
 
 /**
@@ -1909,6 +1904,8 @@ static bool ggml_backend_cann_cpy_tensor_async(
         ggml_cann_set_device(cann_ctx_src->device);
         ACL_CHECK(aclrtDeviceEnablePeerAccess(cann_ctx_dst->device, 0));
 
+        // wait until the task queue is empty to preserve task ordering.
+        cann_ctx_src->task_queue.wait();
         ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size,
                                    ACL_MEMCPY_DEVICE_TO_DEVICE,
                                    cann_ctx_src->stream()));
@@ -1936,9 +1933,8 @@ static bool ggml_backend_cann_cpy_tensor_async(
 static void ggml_backend_cann_synchronize(ggml_backend_t backend) {
     ggml_backend_cann_context* cann_ctx =
         (ggml_backend_cann_context*)backend->context;
-
+    cann_ctx->task_queue.wait();
     ggml_cann_set_device(cann_ctx->device);
-
     ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));
 }
 

From 207c22ec2d6d793fc70830138617d1e016c5151c Mon Sep 17 00:00:00 2001
From: Alan Gray 
Date: Thu, 17 Apr 2025 14:19:42 +0100
Subject: [PATCH 5/6] ggml: Re-enable CUDA graphs in presence of CONT and DUP
 nodes (#12970)

---
 ggml/src/ggml-cuda/cpy.cu       | 9 +++++----
 ggml/src/ggml-cuda/cpy.cuh      | 2 +-
 ggml/src/ggml-cuda/ggml-cuda.cu | 2 +-
 3 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu
index 4f4faa3e6..ed25646e8 100644
--- a/ggml/src/ggml-cuda/cpy.cu
+++ b/ggml/src/ggml-cuda/cpy.cu
@@ -551,7 +551,7 @@ static void ggml_cpy_f16_f16_cuda(
         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
 }
 
-void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1) {
+void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1, bool disable_indirection_for_this_node) {
     const int64_t ne = ggml_nelements(src0);
     GGML_ASSERT(ne == ggml_nelements(src1));
 
@@ -588,7 +588,7 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
     char ** dest_ptrs_d = nullptr;
     int graph_cpynode_index = -1;
 #if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS)
-    if(ctx.cuda_graph->use_cpy_indirection) {
+    if(ctx.cuda_graph->use_cpy_indirection && !disable_indirection_for_this_node) {
         dest_ptrs_d = ctx.cuda_graph->dest_ptrs_d;
         graph_cpynode_index = ctx.cuda_graph->graph_cpynode_index;
     }
@@ -636,7 +636,7 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
                 ggml_type_name(src0->type), ggml_type_name(src1->type));
     }
 #if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS)
-    if(ctx.cuda_graph->use_cpy_indirection) {
+    if(ctx.cuda_graph->use_cpy_indirection && !disable_indirection_for_this_node) {
         ctx.cuda_graph->graph_cpynode_index = graph_cpynode_index;
     }
 #endif
@@ -645,7 +645,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
 
 void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
-    ggml_cuda_cpy(ctx, src0, dst);
+    bool disable_indirection = true;
+    ggml_cuda_cpy(ctx, src0, dst, disable_indirection);
 }
 
 void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
diff --git a/ggml/src/ggml-cuda/cpy.cuh b/ggml/src/ggml-cuda/cpy.cuh
index 6bed0564d..0bd3c0c6f 100644
--- a/ggml/src/ggml-cuda/cpy.cuh
+++ b/ggml/src/ggml-cuda/cpy.cuh
@@ -2,7 +2,7 @@
 
 #define CUDA_CPY_BLOCK_SIZE 64
 
-void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1);
+void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1,  bool disable_indirection = false);
 
 void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 9ced46651..bab85809a 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -2489,7 +2489,7 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
 #endif
         }
 
-        if (node->op == GGML_OP_MUL_MAT_ID || node->op == GGML_OP_CONT || node->op == GGML_OP_DUP) {
+        if (node->op == GGML_OP_MUL_MAT_ID) {
             use_cuda_graph = false; // This node type is not supported by CUDA graph capture
 #ifndef NDEBUG
             GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported node type\n", __func__);
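
The change above keeps CUDA graph capture enabled for CONT/DUP nodes: those copies never need the indirection table that graph re-capture relies on, so they opt out through a defaulted parameter instead of disabling graphs wholesale. A minimal sketch of that call shape (names are illustrative, not the real CUDA backend API):

```
#include <cstdio>

struct cuda_graph_state {
    bool use_cpy_indirection = true;
    int  graph_cpynode_index = 0;
};

// The copy routine only touches the indirection table when the graph wants it
// AND the caller has not opted this particular node out of it.
void cpy_sketch(cuda_graph_state& g, bool disable_indirection_for_this_node = false) {
    if (g.use_cpy_indirection && !disable_indirection_for_this_node) {
        std::printf("register copy #%d in indirection table\n", g.graph_cpynode_index++);
    } else {
        std::printf("plain copy, no graph bookkeeping\n");
    }
}

void dup_sketch(cuda_graph_state& g) {
    // DUP / CONT path: same copy kernel, but opted out of indirection.
    cpy_sketch(g, /*disable_indirection_for_this_node=*/true);
}

int main() {
    cuda_graph_state g;
    cpy_sketch(g);  // regular copy node
    dup_sketch(g);  // DUP node no longer forces CUDA graphs off
}
```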

From 2f74c354c0f752ed9aabf7d3a350e6edebd7e744 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov 
Date: Thu, 17 Apr 2025 18:16:36 +0300
Subject: [PATCH 6/6] graph : make FA compatible with MLA + add initial Metal
 kernels (#12953)

* graph : make mla compatible with FA

* metal : add exp FA kernels for DeepSeek models

ggml-ci

* llama : minor naming updates

ggml-ci

* ggml : disable FA for DS head sizes

* tests : add FA tests for MLA shapes

ggml-ci
---
 ggml/src/ggml-cuda/ggml-cuda.cu      |  4 ++
 ggml/src/ggml-metal/ggml-metal.m     | 83 ++++++++++++++++++++++++++--
 ggml/src/ggml-metal/ggml-metal.metal | 17 ++++++
 ggml/src/ggml-vulkan/ggml-vulkan.cpp |  1 +
 src/llama-context.cpp                | 11 +---
 src/llama-graph.cpp                  | 14 +++--
 src/llama-model.cpp                  |  6 +-
 tests/test-backend-ops.cpp           |  7 ++-
 8 files changed, 117 insertions(+), 26 deletions(-)

diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index bab85809a..a7febef72 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -3237,6 +3237,10 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
             if (op->src[0]->ne[0] == 192) {
                 return false;
             }
+            if (op->src[0]->ne[0] == 576) {
+                // DeepSeek MLA
+                return false;
+            }
             if (op->src[0]->ne[3] != 1) {
                 return false;
             }
diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
index 16ca08383..85f3ae7bf 100644
--- a/ggml/src/ggml-metal/ggml-metal.m
+++ b/ggml/src/ggml-metal/ggml-metal.m
@@ -354,6 +354,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H192,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96,
@@ -362,6 +363,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H192,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96,
@@ -370,6 +372,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H192,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96,
@@ -378,6 +381,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H192,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96,
@@ -386,6 +390,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H192,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96,
@@ -394,6 +399,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H192,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96,
@@ -402,6 +408,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96,
@@ -437,6 +444,13 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK576_HV512,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK576_HV512,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK576_HV512,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512,
     GGML_METAL_KERNEL_TYPE_SET_I32,
     GGML_METAL_KERNEL_TYPE_SET_F32,
     GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
@@ -1018,6 +1032,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H192,         flash_attn_ext_f16_h192,         has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128,  flash_attn_ext_f16_hk192_hv128,  has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,         flash_attn_ext_f16_h256,         has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512,  flash_attn_ext_f16_hk576_hv512,  has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64,         flash_attn_ext_bf16_h64,         has_simdgroup_mm && use_bfloat);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80,         flash_attn_ext_bf16_h80,         has_simdgroup_mm && use_bfloat);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96,         flash_attn_ext_bf16_h96,         has_simdgroup_mm && use_bfloat);
@@ -1026,6 +1041,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H192,        flash_attn_ext_bf16_h192,        has_simdgroup_mm && use_bfloat);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128, flash_attn_ext_bf16_hk192_hv128, has_simdgroup_mm && use_bfloat);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256,        flash_attn_ext_bf16_h256,        has_simdgroup_mm && use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512, flash_attn_ext_bf16_hk576_hv512, has_simdgroup_mm && use_bfloat);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64,         flash_attn_ext_q4_0_h64,         has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80,         flash_attn_ext_q4_0_h80,         has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96,         flash_attn_ext_q4_0_h96,         has_simdgroup_mm);
@@ -1034,6 +1050,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H192,        flash_attn_ext_q4_0_h192,        has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128, flash_attn_ext_q4_0_hk192_hv128, has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256,        flash_attn_ext_q4_0_h256,        has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512, flash_attn_ext_q4_0_hk576_hv512, has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64,         flash_attn_ext_q4_1_h64,         has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80,         flash_attn_ext_q4_1_h80,         has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96,         flash_attn_ext_q4_1_h96,         has_simdgroup_mm);
@@ -1042,6 +1059,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H192,        flash_attn_ext_q4_1_h192,        has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128, flash_attn_ext_q4_1_hk192_hv128, has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256,        flash_attn_ext_q4_1_h256,        has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512, flash_attn_ext_q4_1_hk576_hv512, has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64,         flash_attn_ext_q5_0_h64,         has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80,         flash_attn_ext_q5_0_h80,         has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96,         flash_attn_ext_q5_0_h96,         has_simdgroup_mm);
@@ -1050,6 +1068,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H192,        flash_attn_ext_q5_0_h192,        has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128, flash_attn_ext_q5_0_hk192_hv128, has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256,        flash_attn_ext_q5_0_h256,        has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512, flash_attn_ext_q5_0_hk576_hv512, has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64,         flash_attn_ext_q5_1_h64,         has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80,         flash_attn_ext_q5_1_h80,         has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96,         flash_attn_ext_q5_1_h96,         has_simdgroup_mm);
@@ -1058,6 +1077,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H192,        flash_attn_ext_q5_1_h192,        has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128, flash_attn_ext_q5_1_hk192_hv128, has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256,        flash_attn_ext_q5_1_h256,        has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512, flash_attn_ext_q5_1_hk576_hv512, has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64,         flash_attn_ext_q8_0_h64,         has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80,         flash_attn_ext_q8_0_h80,         has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96,         flash_attn_ext_q8_0_h96,         has_simdgroup_mm);
@@ -1066,6 +1086,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192,        flash_attn_ext_q8_0_h192,        has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128, flash_attn_ext_q8_0_hk192_hv128, has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,        flash_attn_ext_q8_0_h256,        has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512, flash_attn_ext_q8_0_hk576_hv512, has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96,      flash_attn_ext_vec_f16_h96,      has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96,     flash_attn_ext_vec_bf16_h96,     has_simdgroup_reduction && use_bfloat);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96,     flash_attn_ext_vec_q4_0_h96,     has_simdgroup_reduction);
@@ -1101,6 +1122,13 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256,    flash_attn_ext_vec_q5_0_h256,    has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256,    flash_attn_ext_vec_q5_1_h256,    has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256,    flash_attn_ext_vec_q8_0_h256,    has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512,     flash_attn_ext_vec_f16_hk576_hv512,     has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK576_HV512,    flash_attn_ext_vec_bf16_hk576_hv512,    has_simdgroup_reduction && use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK576_HV512,    flash_attn_ext_vec_q4_0_hk576_hv512,    has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK576_HV512,    flash_attn_ext_vec_q4_1_hk576_hv512,    has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512,    flash_attn_ext_vec_q5_0_hk576_hv512,    has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512,    flash_attn_ext_vec_q5_1_hk576_hv512,    has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512,    flash_attn_ext_vec_q8_0_hk576_hv512,    has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_F32,                         set_f32,                         true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_I32,                         set_i32,                         true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32,                     cpy_f32_f32,                     true);
@@ -1365,6 +1393,11 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
                 // TODO: not sure if it is worth adding kernels for this size
                 return false;
             }
+            if (op->src[0]->ne[0] == 576) {
+                // DeepSeek sizes
+                // TODO: disabled for now, until optimized
+                return false;
+            }
             if (op->src[1]->type != op->src[2]->type) {
                 return false;
             }
@@ -3857,12 +3890,14 @@ static void ggml_metal_encode_node(
                 // TODO: add vec kernels for (ne00%64 == 0) and maybe also for (ne00%32 == 0)
                 //       for now avoiding mainly to keep the number of templates/kernels a bit lower
                 //       these are now trivial to add after: https://github.com/ggml-org/llama.cpp/pull/12612
-                if (ne01 >= 4 || (ne00%128 != 0 && ne00 != 96 && ne00 != 192)) {
+                if (ne01 >= 4 || (ne00%128 != 0 && ne00 != 96 && ne00 != 192 && ne00 != 576)) {
                     switch (src1->type) {
                         case GGML_TYPE_F16:
                             {
                                 if (ne00 == 192 && ne20 == 128) {
                                     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128].pipeline;
+                                } else if (ne00 == 576 && ne20 == 512) {
+                                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512].pipeline;
                                 } else {
                                     switch (ne00) {
                                         case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
@@ -3885,6 +3920,8 @@ static void ggml_metal_encode_node(
                             {
                                 if (ne00 == 192 && ne20 == 128) {
                                     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128].pipeline;
+                                } else if (ne00 == 576 && ne20 == 512) {
+                                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512].pipeline;
                                 } else {
                                     switch (ne00) {
                                         case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64 ].pipeline; break;
@@ -3907,6 +3944,8 @@ static void ggml_metal_encode_node(
                             {
                                 if (ne00 == 192 && ne20 == 128) {
                                     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128].pipeline;
+                                } else if (ne00 == 576 && ne20 == 512) {
+                                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512].pipeline;
                                 } else {
                                     switch (ne00) {
                                         case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64 ].pipeline; break;
@@ -3929,6 +3968,8 @@ static void ggml_metal_encode_node(
                             {
                                 if (ne00 == 192 && ne20 == 128) {
                                     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128].pipeline;
+                                } else if (ne00 == 576 && ne20 == 512) {
+                                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512].pipeline;
                                 } else {
                                     switch (ne00) {
                                         case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64 ].pipeline; break;
@@ -3951,6 +3992,8 @@ static void ggml_metal_encode_node(
                             {
                                 if (ne00 == 192 && ne20 == 128) {
                                     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128].pipeline;
+                                } else if (ne00 == 576 && ne20 == 512) {
+                                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512].pipeline;
                                 } else {
                                     switch (ne00) {
                                         case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64 ].pipeline; break;
@@ -3973,6 +4016,8 @@ static void ggml_metal_encode_node(
                             {
                                 if (ne00 == 192 && ne20 == 128) {
                                     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128].pipeline;
+                                } else if (ne00 == 576 && ne20 == 512) {
+                                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512].pipeline;
                                 } else {
                                     switch (ne00) {
                                         case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64 ].pipeline; break;
@@ -3995,6 +4040,8 @@ static void ggml_metal_encode_node(
                             {
                                 if (ne00 == 192 && ne20 == 128) {
                                     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128].pipeline;
+                                } else if (ne00 == 576 && ne20 == 512) {
+                                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512].pipeline;
                                 } else {
                                     switch (ne00) {
                                         case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64 ].pipeline; break;
@@ -4114,12 +4161,36 @@ static void ggml_metal_encode_node(
                                         }
                                 }
                             } break;
+                        case 576:
+                            {
+                                if (ne20 == 512) {
+                                    switch (src1->type) {
+                                        case GGML_TYPE_F16:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512].pipeline; break;
+                                        case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK576_HV512].pipeline; break;
+                                        case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK576_HV512].pipeline; break;
+                                        case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK576_HV512].pipeline; break;
+                                        case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512].pipeline; break;
+                                        case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512].pipeline; break;
+                                        case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512].pipeline; break;
+                                        default:
+                                            {
+                                                GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
+                                                GGML_LOG_ERROR("add template specialization for this type\n");
+                                                GGML_ABORT("add template specialization for this type");
+                                            }
+                                    }
+                                } else {
+                                    GGML_LOG_ERROR("unsupported size: %lld\n", ne20);
+                                    GGML_LOG_ERROR("add template specialization for this size\n");
+                                    GGML_ABORT("add template specialization for this size");
+                                }
+                            } break;
                         default:
-                                  {
-                                      GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
-                                      GGML_LOG_ERROR("add template specialization for this size\n");
-                                      GGML_ABORT("add template specialization for this size");
-                                  }
+                            {
+                                GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
+                                GGML_LOG_ERROR("add template specialization for this size\n");
+                                GGML_ABORT("add template specialization for this size");
+                            }
                     }
                 }
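
The Metal dispatch above picks a pipeline from the (K head size, V head size) pair, with 576/512 being the DeepSeek MLA shape that the new HK576_HV512 kernels cover. A condensed sketch of that selection logic, using strings in place of the real pipeline objects (F16 variant only, for brevity):

```
#include <cstdio>
#include <string>

// Sketch of head-size based kernel selection: ne00 is the K head size,
// ne20 the V head size. Strings stand in for the Metal pipeline objects.
std::string pick_fa_kernel(int ne00, int ne20) {
    if (ne00 == 192 && ne20 == 128) return "flash_attn_ext_f16_hk192_hv128";
    if (ne00 == 576 && ne20 == 512) return "flash_attn_ext_f16_hk576_hv512"; // DeepSeek MLA
    switch (ne00) {
        case 64:  return "flash_attn_ext_f16_h64";
        case 128: return "flash_attn_ext_f16_h128";
        case 256: return "flash_attn_ext_f16_h256";
        default:  return "unsupported";
    }
}

int main() {
    std::printf("%s\n", pick_fa_kernel(576, 512).c_str());
    std::printf("%s\n", pick_fa_kernel(128, 128).c_str());
}
```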
 
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index 0fb28238a..dc7eab03e 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -3546,6 +3546,7 @@ template [[host_name("kernel_flash_attn_ext_f16_h128")]]         kernel flash_at
 template [[host_name("kernel_flash_attn_ext_f16_h192")]]         kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_f16_hk192_hv128")]]  kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_f16_h256")]]         kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_f16_hk576_hv512")]]  kernel flash_attn_ext_t kernel_flash_attn_ext;
 
 #if defined(GGML_METAL_USE_BF16)
 template [[host_name("kernel_flash_attn_ext_bf16_h64" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
@@ -3556,6 +3557,7 @@ template [[host_name("kernel_flash_attn_ext_bf16_h128")]]        kernel flash_at
 template [[host_name("kernel_flash_attn_ext_bf16_h192")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_bf16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_bf16_h256")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_bf16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 #endif
 
 template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
@@ -3566,6 +3568,7 @@ template [[host_name("kernel_flash_attn_ext_q4_0_h128")]]        kernel flash_at
 template [[host_name("kernel_flash_attn_ext_q4_0_h192")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q4_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q4_0_h256")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q4_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 
 template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
@@ -3575,6 +3578,7 @@ template [[host_name("kernel_flash_attn_ext_q4_1_h128")]]        kernel flash_at
 template [[host_name("kernel_flash_attn_ext_q4_1_h192")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q4_1_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q4_1_h256")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q4_1_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 
 template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
@@ -3584,6 +3588,7 @@ template [[host_name("kernel_flash_attn_ext_q5_0_h128")]]        kernel flash_at
 template [[host_name("kernel_flash_attn_ext_q5_0_h192")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q5_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q5_0_h256")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q5_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 
 template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
@@ -3593,6 +3598,7 @@ template [[host_name("kernel_flash_attn_ext_q5_1_h128")]]        kernel flash_at
 template [[host_name("kernel_flash_attn_ext_q5_1_h192")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q5_1_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q5_1_h256")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q5_1_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 
 template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
@@ -3602,6 +3608,7 @@ template [[host_name("kernel_flash_attn_ext_q8_0_h128")]]        kernel flash_at
 template [[host_name("kernel_flash_attn_ext_q8_0_h192")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q8_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q8_0_h256")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q8_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 
 #undef FA_TYPES
 
@@ -4009,6 +4016,16 @@ template [[host_name("kernel_flash_attn_ext_vec_q5_0_h256")]] kernel flash_attn_
 template [[host_name("kernel_flash_attn_ext_vec_q5_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
 template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
 
+template [[host_name("kernel_flash_attn_ext_vec_f16_hk576_hv512")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+#if defined(GGML_METAL_USE_BF16)
+template [[host_name("kernel_flash_attn_ext_vec_bf16_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+#endif
+template [[host_name("kernel_flash_attn_ext_vec_q4_0_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q4_1_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q5_0_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q5_1_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q8_0_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+
 #undef FA_TYPES
 
 template
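The new hk576_hv512 instantiations match DeepSeek's MLA head geometry: with the absorption optimization, queries and keys carry the compressed KV dimension plus the RoPE part, while values stay in the compressed space. A minimal sketch, assuming DeepSeek-V2/V3-style hyperparameters (not read from any model):

```
// Minimal sketch with assumed DeepSeek-V2/V3-style hyperparameters:
// why the new instantiations use HK = 576 and HV = 512.
#include <cstdio>

int main() {
    const int kv_lora_rank     = 512; // compressed ("latent") KV dimension (assumed)
    const int qk_rope_head_dim = 64;  // RoPE part of each query/key head (assumed)
    const int v_head_dim       = 128; // per-head value size after the v_mla projection (assumed)

    const int hk = kv_lora_rank + qk_rope_head_dim; // 576: head size seen by the FA kernel for Q/K
    const int hv = kv_lora_rank;                    // 512: head size seen by the FA kernel for V

    printf("HK = %d, HV = %d, per-head output after v_mla = %d\n", hk, hv, v_head_dim);
    return 0;
}
```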
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 0e9b2e813..2f6d03c93 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -9261,6 +9261,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                 case 112:
                 case 128:
                 case 256:
+                case 575: // DeepSeek MLA
                     break;
                 default:
                     return false;
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index d3ef1cbde..983385f86 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -484,7 +484,7 @@ ggml_tensor * llama_context::build_rope_shift(
 
     // See llm_build_deepseek2() for why attn_factor has to be scaled for YaRN RoPE to work correctly.
     // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
-    const float yarn_attn_factor_scaled = model.arch == LLM_ARCH_DEEPSEEK2 ? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)) : cparams.yarn_attn_factor;
+    const float yarn_attn_factor = model.arch == LLM_ARCH_DEEPSEEK2 ? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)) : cparams.yarn_attn_factor;
 
     ggml_tensor * tmp;
 
@@ -504,14 +504,14 @@ ggml_tensor * llama_context::build_rope_shift(
 
         tmp = ggml_rope_ext_inplace(ctx0, tmp,
                 shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                yarn_ext_factor, yarn_attn_factor_scaled, yarn_beta_fast, yarn_beta_slow);
+                yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
 
         tmp = ggml_cpy(ctx0, tmp, cur);
     } else {
         // we rotate only the first n_rot dimensions
         tmp = ggml_rope_ext_inplace(ctx0, cur,
                 shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                yarn_ext_factor, yarn_attn_factor_scaled, yarn_beta_fast, yarn_beta_slow);
+                yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
     }
 
     return tmp;
@@ -2278,11 +2278,6 @@ llama_context * llama_init_from_model(
         params.flash_attn = false;
     }
 
-    if (params.flash_attn && model->arch == LLM_ARCH_DEEPSEEK2) {
-        LLAMA_LOG_WARN("%s: flash_attn is not compatible with Deepseek2 - forcing off\n", __func__);
-        params.flash_attn = false;
-    }
-
     if (ggml_is_quantized(params.type_v) && !params.flash_attn) {
         LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__);
         return nullptr;
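For reference, a quick numerical check of the DeepSeek2 branch above; freq_scale is an assumed example value (a YaRN scaling factor of 40), the real one comes from the context parameters:

```
// Numerical check of the yarn_attn_factor expression used in build_rope_shift
// (freq_scale below is an assumed example value, not read from a model).
#include <cmath>
#include <cstdio>

int main() {
    const float freq_scale = 0.025f; // e.g. a YaRN context-scaling factor of 40
    const float yarn_attn_factor = 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale));
    printf("yarn_attn_factor = %.4f\n", yarn_attn_factor); // ~0.7305
    return 0;
}
```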
diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index 5d0222b98..a85e97288 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -1200,9 +1200,6 @@ ggml_tensor * llm_graph_context::build_attn_mha(
   //const auto & n_embd_head_k = hparams.n_embd_head_k;
   //const auto & n_embd_head_v = hparams.n_embd_head_v;
 
-    // note: for MLA with the absorption optimization, the final embedding size will be changed via v_mla
-    const auto n_embd_head_v = v_mla == nullptr ? v_trans ? v->ne[1] : v->ne[0] : v_mla->ne[1];
-
     const auto n_tokens = q->ne[1];
     const auto n_head   = q->ne[2];
     const auto n_kv     = k->ne[1];
@@ -1231,7 +1228,12 @@ ggml_tensor * llm_graph_context::build_attn_mha(
 
         ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
 
-        cur = ggml_reshape_2d(ctx0, cur, n_embd_head_v*n_head, n_tokens);
+        if (v_mla) {
+            cur = ggml_reshape_4d(ctx0, cur, v_mla->ne[0], 1, n_head, n_tokens);
+            cur = ggml_mul_mat(ctx0, v_mla, cur);
+        }
+
+        cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
     } else {
         ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
 
@@ -1274,9 +1276,9 @@ ggml_tensor * llm_graph_context::build_attn_mha(
             kqv = ggml_mul_mat(ctx0, v_mla, kqv);
         }
 
-        ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
+        cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
 
-        cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens);
+        cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
 
         if (!cparams.offload_kqv) {
             // all nodes between the KV store and the attention output are run on the CPU
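The new flash-attention path defers the v_mla projection until after the kernel: the FA output is reshaped to one column per head and multiplied by v_mla to recover the per-head output size. A shape walk-through with assumed dimensions (plain integers, no ggml calls):

```
// Shape walk-through of the v_mla path above; head count, token count and the
// v_mla output width are made-up example values.
#include <cstdio>

int main() {
    const int hv       = 512; // FA output width per head (latent V size)
    const int v_out    = 128; // v_mla->ne[1]: per-head size after the projection (assumed)
    const int n_head   = 16;  // assumed
    const int n_tokens = 4;   // assumed

    // cur after flash attention, flattened: [hv * n_head, n_tokens]
    // ggml_reshape_4d(..., v_mla->ne[0], 1, n_head, n_tokens) -> [512, 1, 16, 4]
    int cur[4] = { hv, 1, n_head, n_tokens };

    // ggml_mul_mat(v_mla, cur): v_mla is [512, 128], so result ne[0] = v_mla->ne[1]
    int out[4] = { v_out, cur[1], cur[2], cur[3] }; // [128, 1, 16, 4]

    // final ggml_reshape_2d(..., cur->ne[0] * n_head, n_tokens)
    printf("attn output: [%d, %d]\n", out[0] * n_head, n_tokens); // [2048, 4]
    return 0;
}
```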
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 248c61748..6b7bfecf3 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -10050,7 +10050,7 @@ struct llm_build_deepseek2 : public llm_graph_context {
         // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
         const float mscale = attn_factor * (1.0f + hparams.rope_yarn_log_mul * logf(1.0f / freq_scale));
         const float kq_scale = 1.0f*mscale*mscale/sqrtf(float(n_embd_head_k));
-        const float attn_factor_scaled = 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale));
+        const float attn_factor = 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale));
 
         ggml_tensor * cur;
         ggml_tensor * inpL;
@@ -10127,13 +10127,13 @@ struct llm_build_deepseek2 : public llm_graph_context {
 
                 q_pe = ggml_rope_ext(ctx0, q_pe, inp_pos, nullptr,
                         n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                        ext_factor, attn_factor_scaled, beta_fast, beta_slow
+                        ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(q_pe, "q_pe", il);
 
                 k_pe = ggml_rope_ext(ctx0, k_pe, inp_pos, nullptr,
                         n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                        ext_factor, attn_factor_scaled, beta_fast, beta_slow
+                        ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(k_pe, "k_pe", il);
 
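A quick sanity check of the mscale / kq_scale expressions above, using assumed example hyperparameters rather than values from a GGUF:

```
// Numeric sanity check of the mscale / kq_scale expressions in llm_build_deepseek2
// (all inputs below are assumed example values).
#include <cmath>
#include <cstdio>

int main() {
    const float attn_factor       = 1.0f;   // graph-level attention factor (assumed)
    const float rope_yarn_log_mul = 0.1f;   // hparams.rope_yarn_log_mul (assumed)
    const float freq_scale        = 0.025f; // YaRN scaling factor of 40 (assumed)
    const int   n_embd_head_k     = 192;    // qk_nope_head_dim + qk_rope_head_dim (assumed)

    const float mscale   = attn_factor * (1.0f + rope_yarn_log_mul * logf(1.0f / freq_scale));
    const float kq_scale = 1.0f * mscale * mscale / sqrtf((float) n_embd_head_k);

    printf("mscale = %.4f, kq_scale = %.4f\n", mscale, kq_scale); // ~1.3689, ~0.1352
    return 0;
}
```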
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 1ee742894..5f6f87d1a 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -4428,10 +4428,11 @@ static std::vector> make_test_cases_eval() {
     test_cases.emplace_back(new test_timestep_embedding());
     test_cases.emplace_back(new test_leaky_relu());
 
-    for (int hsk : { 64, 80, 128, 192, 256, }) {
-        for (int hsv : { 64, 80, 128, 192, 256, }) {
-            if (hsk != 192 && hsk != hsv) continue;
+    for (int hsk : { 64, 80, 128, 192, 256, 576 }) {
+        for (int hsv : { 64, 80, 128, 192, 256, 512 }) {
+            if (hsk != 192 && hsk != 576 && hsk != hsv) continue;
             if (hsk == 192 && (hsv != 128 && hsv != 192)) continue;
+            if (hsk == 576 && hsv != 512) continue; // DeepSeek MLA
 
             for (bool mask : { true, false } ) {
                 for (float max_bias : { 0.0f, 8.0f }) {