From 2dabf759e7c8c827d38cdd18d0792070e6a4f4e1 Mon Sep 17 00:00:00 2001
From: dm4
Date: Tue, 8 Apr 2025 21:49:13 +0800
Subject: [PATCH 01/20] llava: add more helper functions to check projector
 types in clip context (#12824)

Signed-off-by: dm4
---
 examples/llava/clip.cpp | 9 +++++++++
 examples/llava/clip.h   | 2 ++
 2 files changed, 11 insertions(+)

diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp
index 1145c816c..e9520f3d1 100644
--- a/examples/llava/clip.cpp
+++ b/examples/llava/clip.cpp
@@ -2840,10 +2840,19 @@ int clip_is_minicpmv(const struct clip_ctx * ctx) {
 bool clip_is_glm(const struct clip_ctx * ctx) {
     return ctx->has_glm_projector;
 }
+
 bool clip_is_qwen2vl(const struct clip_ctx * ctx) {
     return ctx->has_qwen2vl_merger;
 }
 
+bool clip_is_llava(const struct clip_ctx * ctx) {
+    return ctx->has_llava_projector;
+}
+
+bool clip_is_gemma3(const struct clip_ctx * ctx) {
+    return ctx->proj_type == PROJECTOR_TYPE_GEMMA3;
+}
+
 // Determine the number of encoder layers to iterate over
 int get_deepest_feature_layer(const struct clip_ctx * ctx) {
     // Get the index of the second to last layer; this is the
diff --git a/examples/llava/clip.h b/examples/llava/clip.h
index 783bdca3e..d806465bf 100644
--- a/examples/llava/clip.h
+++ b/examples/llava/clip.h
@@ -106,6 +106,8 @@ CLIP_API bool clip_model_quantize(const char * fname_inp, const char * fname_out
 CLIP_API int clip_is_minicpmv(const struct clip_ctx * ctx);
 CLIP_API bool clip_is_glm(const struct clip_ctx * ctx);
 CLIP_API bool clip_is_qwen2vl(const struct clip_ctx * ctx);
+CLIP_API bool clip_is_llava(const struct clip_ctx * ctx);
+CLIP_API bool clip_is_gemma3(const struct clip_ctx * ctx);
 
 CLIP_API int get_deepest_feature_layer(const struct clip_ctx * ctx);
 

From 78a1ba0a4f2bfed5b8b8e312592143d22e531698 Mon Sep 17 00:00:00 2001
From: Xuan-Son Nguyen
Date: Tue, 8 Apr 2025 18:37:06 +0200
Subject: [PATCH 02/20] server : fix thread.join() on exit (#12831)

---
 examples/server/server.cpp | 23 ++++++++++++++++++++---
 1 file changed, 20 insertions(+), 3 deletions(-)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 760c36464..1bf1ee876 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -1705,6 +1705,8 @@ private:
 };
 
 struct server_response {
+    bool running = true;
+
     // for keeping track of all tasks waiting for the result
     std::unordered_set<int> waiting_task_ids;
 
@@ -1759,6 +1761,10 @@ struct server_response {
         while (true) {
             std::unique_lock<std::mutex> lock(mutex_results);
             condition_results.wait(lock, [&]{
+                if (!running) {
+                    SRV_DBG("%s : queue result stop\n", __func__);
+                    std::terminate(); // we cannot return here since the caller is HTTP code
+                }
                 return !queue_results.empty();
             });
 
@@ -1789,6 +1795,10 @@ struct server_response {
             }
 
             std::cv_status cr_res = condition_results.wait_for(lock, std::chrono::seconds(timeout));
+            if (!running) {
+                SRV_DBG("%s : queue result stop\n", __func__);
+                std::terminate(); // we cannot return here since the caller is HTTP code
+            }
             if (cr_res == std::cv_status::timeout) {
                 return nullptr;
             }
@@ -1818,6 +1828,12 @@ struct server_response {
             }
         }
     }
+
+    // terminate the waiting loop
+    void terminate() {
+        running = false;
+        condition_results.notify_all();
+    }
 };
 
 struct server_context {
@@ -4491,9 +4507,10 @@ int main(int argc, char ** argv) {
     svr->new_task_queue = [&params] { return new httplib::ThreadPool(params.n_threads_http); };
 
     // clean up function, to be called before exit
-    auto clean_up = [&svr]() {
+    auto clean_up = [&svr, &ctx_server]() {
         SRV_INF("%s: cleaning up before exit...\n", __func__);
         svr->stop();
+        ctx_server.queue_results.terminate();
         llama_backend_free();
     };
 
@@ -4534,7 +4551,7 @@ int main(int argc, char ** argv) {
 
     if (!ctx_server.load_model(params)) {
         clean_up();
-        // t.join(); // FIXME: see below
+        t.join();
        LOG_ERR("%s: exiting due to model loading error\n", __func__);
        return 1;
    }
@@ -4582,7 +4599,7 @@ int main(int argc, char ** argv) {
     ctx_server.queue_tasks.start_loop();
 
     clean_up();
-    // t.join(); // FIXME: http thread may stuck if there is an on-going request. we don't need to care about this for now as the HTTP connection will already be closed at this point, but it's better to fix this
+    t.join();
 
     return 0;
 }

From a19b5cef16d885c44c635da4a5c97113c1577de8 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Tue, 8 Apr 2025 19:54:51 +0300
Subject: [PATCH 03/20] llama : fix FA when KV cache is not used (i.e.
 embeddings) (#12825)

* ggml : FA supports F32 V

* graph : cast KV to F16 when the KV cache is not used

ggml-ci

* server : add test that exercises embeddings with FA enabled

ggml-ci
---
 examples/server/tests/unit/test_embedding.py | 20 ++++++++++++++++++++
 examples/server/tests/utils.py               | 15 +++++++++++++++
 examples/server_embd.py                      |  2 +-
 ggml/src/ggml-cpu/ops.cpp                    | 14 +++++++++-----
 ggml/src/ggml-metal/ggml-metal.m             |  5 +++++
 src/llama-graph.cpp                          |  9 +++++++++
 6 files changed, 59 insertions(+), 6 deletions(-)

diff --git a/examples/server/tests/unit/test_embedding.py b/examples/server/tests/unit/test_embedding.py
index 8b0eb42b0..0feb452cc 100644
--- a/examples/server/tests/unit/test_embedding.py
+++ b/examples/server/tests/unit/test_embedding.py
@@ -49,6 +49,26 @@ def test_embedding_multiple():
         assert len(d['embedding']) > 1
 
 
+def test_embedding_multiple_with_fa():
+    server = ServerPreset.bert_bge_small_with_fa()
+    server.pooling = 'last'
+    server.start()
+    # one of these should trigger the FA branch (i.e. context size % 256 == 0)
+    res = server.make_request("POST", "/v1/embeddings", data={
+        "input": [
+            "a "*253,
+            "b "*254,
+            "c "*255,
+            "d "*256,
+        ],
+    })
+    assert res.status_code == 200
+    assert len(res.body['data']) == 4
+    for d in res.body['data']:
+        assert 'embedding' in d
+        assert len(d['embedding']) > 1
+
+
 @pytest.mark.parametrize(
     "input,is_multi_prompt",
     [
diff --git a/examples/server/tests/utils.py b/examples/server/tests/utils.py
index 30aa86609..4dc2062a8 100644
--- a/examples/server/tests/utils.py
+++ b/examples/server/tests/utils.py
@@ -323,6 +323,21 @@ class ServerPreset:
         server.server_embeddings = True
         return server
 
+    @staticmethod
+    def bert_bge_small_with_fa() -> ServerProcess:
+        server = ServerProcess()
+        server.model_hf_repo = "ggml-org/models"
+        server.model_hf_file = "bert-bge-small/ggml-model-f16.gguf"
+        server.model_alias = "bert-bge-small"
+        server.n_ctx = 1024
+        server.n_batch = 300
+        server.n_ubatch = 300
+        server.n_slots = 2
+        server.fa = True
+        server.seed = 42
+        server.server_embeddings = True
+        return server
+
     @staticmethod
     def tinyllama_infill() -> ServerProcess:
         server = ServerProcess()
diff --git a/examples/server_embd.py b/examples/server_embd.py
index 0e34c6cea..f8b0ffecd 100644
--- a/examples/server_embd.py
+++ b/examples/server_embd.py
@@ -15,7 +15,7 @@ async def main():
     model_url = "http://127.0.0.1:6900"
     responses: list[requests.Response] = await asyncio.gather(*[requests_post_async(
         url= f"{model_url}/embedding",
-        json= {"content": str(0)*1024}
+        json= {"content": "a "*1022}
     ) for i in range(n)])
 
     for response in responses:
diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
index 7a8d5ac6f..f63656be5 100644
--- a/ggml/src/ggml-cpu/ops.cpp
+++ b/ggml/src/ggml-cpu/ops.cpp
@@ -6721,8 +6721,8 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     ggml_vec_dot_t    const kq_vec_dot = ggml_get_type_traits_cpu(k->type)->vec_dot;
     ggml_to_float_t   const v_to_float = ggml_get_type_traits(v->type)->to_float;
 
-    GGML_ASSERT(q_to_vec_dot && "fattn: unsupported K-type");
-    GGML_ASSERT(v_to_float   && "fattn: unsupported V-type");
+    GGML_ASSERT((                            q_to_vec_dot) && "fattn: unsupported K-type");
+    GGML_ASSERT((v->type == GGML_TYPE_F32 || v_to_float  ) && "fattn: unsupported V-type");
 
     // loop over n_batch and n_head
     for (int ir = ir0; ir < ir1; ++ir) {
@@ -6818,10 +6818,14 @@ static void ggml_compute_forward_flash_attn_ext_f16(
                     vs = expf(s - M);
                 }
 
-                v_to_float(v_data, V32, DV);
-
                 // V += v*expf(s - M)
-                ggml_vec_mad_f32(DV, VKQ32, V32, vs);
+                if (v_to_float) {
+                    v_to_float(v_data, V32, DV);
+                    ggml_vec_mad_f32(DV, VKQ32, V32, vs);
+                } else {
+                    // V is F32
+                    ggml_vec_mad_f32(DV, VKQ32, (const float *) v_data, vs);
+                }
             }
 
             S = S*ms + vs; // scale and increment sum with partial sum
diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
index 456e1fd99..f22682602 100644
--- a/ggml/src/ggml-metal/ggml-metal.m
+++ b/ggml/src/ggml-metal/ggml-metal.m
@@ -1345,6 +1345,11 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
         case GGML_OP_ARANGE:
             return true;
         case GGML_OP_FLASH_ATTN_EXT:
+            if (op->src[0]->ne[0] == 32) {
+                // head size == 32 (e.g. bert-bge-small)
+                // TODO: not sure if it is worth adding kernels for this size
+                return false;
+            }
             if (op->src[1]->type != op->src[2]->type) {
                 return false;
             }
diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index c3469177e..cd955d63b 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -1215,6 +1215,15 @@ ggml_tensor * llm_graph_context::build_attn_mha(
         v = ggml_transpose(ctx0, v);
     }
 
+    // this can happen when KV cache is not used (e.g. an embedding model with non-causal attn)
+    if (k->type == GGML_TYPE_F32) {
+        k = ggml_cast(ctx0, k, GGML_TYPE_F16);
+    }
+
+    if (v->type == GGML_TYPE_F32) {
+        v = ggml_cast(ctx0, v, GGML_TYPE_F16);
+    }
+
     cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
                               hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
 

From b32efad2bc42460637c3a364c9554ea8217b3d7f Mon Sep 17 00:00:00 2001
From: Matt Clayton <156335168+mattjcly@users.noreply.github.com>
Date: Tue, 8 Apr 2025 16:01:58 -0400
Subject: [PATCH 04/20] llava: improve clip_ctx destructor to not memleak
 load_image_size (#12834)

---
 examples/llava/clip.cpp | 10 ++++++++++
 examples/llava/clip.h   |  1 +
 2 files changed, 11 insertions(+)

diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp
index e9520f3d1..44428cc95 100644
--- a/examples/llava/clip.cpp
+++ b/examples/llava/clip.cpp
@@ -380,6 +380,7 @@ struct clip_ctx {
         if (backend_cpu != backend) {
             ggml_backend_free(backend_cpu);
         }
+        clip_image_size_free(load_image_size);
     }
 };
 
@@ -1618,6 +1619,12 @@ struct clip_image_f32 * clip_image_f32_init() {
     return new clip_image_f32();
 }
 
+void clip_image_size_free(struct clip_image_size * load_image_size) {
+    if (load_image_size == nullptr) {
+        return;
+    }
+    delete load_image_size;
+}
 void clip_image_u8_free(struct clip_image_u8 * img) { delete img; }
 void clip_image_f32_free(struct clip_image_f32 * img) { delete img; }
 void clip_image_u8_batch_free(struct clip_image_u8_batch * batch) {
@@ -2270,6 +2277,9 @@ ggml_tensor * clip_get_newline_tensor(const struct clip_ctx * ctx) {
 }
 
 void clip_free(clip_ctx * ctx) {
+    if (ctx == nullptr) {
+        return;
+    }
     delete ctx;
 }
 
diff --git a/examples/llava/clip.h b/examples/llava/clip.h
index d806465bf..87aa61574 100644
--- a/examples/llava/clip.h
+++ b/examples/llava/clip.h
@@ -77,6 +77,7 @@ CLIP_API struct clip_image_size * clip_image_size_init();
 CLIP_API struct clip_image_u8    * clip_image_u8_init ();
 CLIP_API struct clip_image_f32   * clip_image_f32_init();
 
+CLIP_API void clip_image_size_free (struct clip_image_size * img_size);
 CLIP_API void clip_image_u8_free (struct clip_image_u8  * img);
 CLIP_API void clip_image_f32_free(struct clip_image_f32 * img);
 CLIP_API void clip_image_u8_batch_free (struct clip_image_u8_batch  * batch);

From 7538246e7ce0606694c38055cc2fc9f60535be6c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?=
Date: Tue, 8 Apr 2025 23:21:31 +0200
Subject: [PATCH 05/20] cuda : add f32 to bf16 copy op (#12806)

This allows BF16 KV-cache on CUDA.
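
For reference, the device-side conversion below simply assigns through
nv_bfloat16, which rounds to nearest-even. A minimal CPU-side sketch of
the same truncation (a hypothetical helper for illustration, not part of
this patch):

    #include <cstdint>
    #include <cstring>

    // f32 -> bf16: keep the top 16 bits of the IEEE-754 binary32
    // representation, rounding to nearest-even (what nv_bfloat16's
    // converting assignment does on the device).
    static uint16_t f32_to_bf16(float x) {
        uint32_t bits;
        std::memcpy(&bits, &x, sizeof(bits));       // type-pun safely
        if ((bits & 0x7fffffffu) > 0x7f800000u) {   // NaN: keep it a quiet NaN
            return (uint16_t) ((bits >> 16) | 0x0040);
        }
        bits += 0x7fffu + ((bits >> 16) & 1);       // round to nearest, ties to even
        return (uint16_t) (bits >> 16);
    }

With the copy op in place, a BF16 KV cache can be selected at run time,
e.g. with `llama-cli ... -ctk bf16 -ctv bf16` (assuming a CUDA build; the
cache write path is a GGML_OP_CPY from the F32 K/V tensors into the cache
type).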
---
 ggml/src/ggml-cuda/cpy.cu       | 21 +++++++++++++++++++++
 ggml/src/ggml-cuda/ggml-cuda.cu |  3 +++
 2 files changed, 24 insertions(+)

diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu
index ed853ee6c..4f4faa3e6 100644
--- a/ggml/src/ggml-cuda/cpy.cu
+++ b/ggml/src/ggml-cuda/cpy.cu
@@ -10,6 +10,13 @@ static __device__ void cpy_1_f32_f32(const char * cxi, char * cdsti) {
     *dsti = *xi;
 }
 
+static __device__ void cpy_1_f32_bf16(const char * cxi, char * cdsti) {
+    const float * xi = (const float *) cxi;
+    nv_bfloat16 * dsti = (nv_bfloat16 *) cdsti;
+
+    *dsti = *xi;
+}
+
 static __device__ void cpy_1_f32_f16(const char * cxi, char * cdsti) {
     const float * xi = (const float *) cxi;
     half * dsti = (half *) cdsti;
@@ -386,6 +393,16 @@ static void ggml_cpy_f32_f32_cuda(
         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
 }
 
+static void ggml_cpy_f32_bf16_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
+
+    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+    cpy_f32_f16<cpy_1_f32_bf16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
+}
+
 static void ggml_cpy_f32_f16_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -581,6 +598,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
         CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
         ggml_cpy_f32_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
+        ggml_cpy_f32_bf16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
         ggml_cpy_f32_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
@@ -634,6 +653,8 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
         return nullptr;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
         return (void*) cpy_f32_f16<cpy_1_f32_f32>;
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
+        return (void*) cpy_f32_f16<cpy_1_f32_bf16>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
         return (void*) cpy_f32_f16<cpy_1_f32_f16>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 78717df1a..633456a92 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -3079,6 +3079,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
             if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
                 return true;
             }
+            if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_BF16) {
+                return true;
+            }
             if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
                 return true;
             }

From 7ecd780b1a1d5214b8d04c25ebfc194d310816ed Mon Sep 17 00:00:00 2001
From: Jeff Bolz
Date: Wed, 9 Apr 2025 00:12:57 -0500
Subject: [PATCH 06/20] vulkan: Use fp16 for the flash attention P*V
 multiplication (#12783)

This is consistent with the ggml-cuda behavior and the mul_mat fallback.
---
 ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp
index 8ddadb8a1..a8f4bc417 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp
@@ -330,9 +330,11 @@ void main() {
         // resize eM by using smear/reduce
         coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce);
 
-        O = eMdiag * O;
+        // multiply with fp16 accumulation, then add to O.
+        coopmat PV = coopmat(0);
+        PV = coopMatMulAdd(P_A, V, PV);
 
-        O = coopMatMulAdd(P_A, V, O);
+        O = eMdiag * O + coopmat(PV);
     }
 
     // If there is split_k, then the split_k resolve shader does the final

From 0090950f679475c5ecaac2f7bca5049cca96492b Mon Sep 17 00:00:00 2001
From: Jeff Bolz
Date: Wed, 9 Apr 2025 00:25:08 -0500
Subject: [PATCH 07/20] vulkan: In coopmat2 mmq, load q4_k/q5_k scales through
 shared memory (#12833)

q4_k and q5_k had a lot of redundant global loads where the same 16B of
scale information is repeatedly loaded and decoded during each loop
iteration. This change restructures the loops to more explicitly iterate
over whole blocks in the outer loop (with unrolled inner loop) and to
copy/decode the scale data into shared memory once at the start of each
outer loop. The copy is pipelined so the scale load from global memory is
relatively cheap.

This improves q4_k/q5_k model prompt processing performance by around
5-7%.

I briefly tried applying this to q6_k and q4_0, and it didn't help for
q6_k and hurt for q4_0.

The big "else" path in mul_mm_cm2.comp that had all the clamped/unclamped
variants isn't used as often as it originally was (e.g. due to the
padded_N change), so I trimmed it down to offset some of the new
complexity of the semi-manual loop unrolling.
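
For context, the 16B of scale information per block is the standard
K-quant layout: 12 bytes packing eight 6-bit (scale, min) pairs, which
the shader loads as three 32-bit words. A CPU-side sketch of the word-wise
decode that store_scalesQ4_K performs into shared memory (hypothetical
names; it assumes the three words hold bytes scales[0..3], scales[4..7]
and scales[8..11] of the block):

    #include <cstdint>

    // Decode (scale, min) pair "is" in 0..7 from the three packed words.
    static void decode_scale_min(uint32_t scale0, uint32_t scale4, uint32_t scale8,
                                 uint32_t is, uint32_t & sc, uint32_t & mbyte) {
        const uint32_t sc_lo = scale0;   // pairs 0..3: 6 low bits of each byte
        const uint32_t mb_lo = scale4;
        // pairs 4..7: low nibble from scales[8..11], top two bits taken from
        // the high bits of scales[0..3] (scale) and scales[4..7] (min)
        const uint32_t sc_hi = (scale8 & 0x0F0F0F0F) | ((scale0 & 0xC0C0C0C0) >> 2);
        const uint32_t mb_hi = ((scale8 & 0xF0F0F0F0) >> 4) | ((scale4 & 0xC0C0C0C0) >> 2);

        sc    = ((is < 4 ? sc_lo : sc_hi) >> (8 * (is & 3))) & 0x3F;
        mbyte = ((is < 4 ? mb_lo : mb_hi) >> (8 * (is & 3))) & 0x3F;
    }

In the shader, each invocation decodes its share of the 8 pairs per row
(8 / (BLOCK_SIZE / BM) of them), and the two barriers in store_scalesQ4_K
make the whole tile of decoded vec2(d, m) values visible before the
unrolled inner loop consumes them.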
---
 ggml/src/ggml-vulkan/ggml-vulkan.cpp          |   6 +
 .../vulkan-shaders/dequant_funcs_cm2.comp     | 116 ++++++++++++-
 .../vulkan-shaders/mul_mm_cm2.comp            | 155 +++++++++++++-----
 3 files changed, 235 insertions(+), 42 deletions(-)

diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 705a6135a..e69d00ad5 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -4194,6 +4194,12 @@ static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, int m, int
             if (split_k == 3) {
                 split_k = 2;
             }
+            if (ctx->device->coopmat2) {
+                // coopmat2 shader expects splits to be aligned to 256
+                while (split_k > 1 && ((k / split_k) % 256) != 0) {
+                    split_k /= 2;
+                }
+            }
         }
     }
 
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp
index b3fad35e2..962d2353f 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp
@@ -167,6 +167,101 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4
     block_q4_K_packed128 block;
 };
 
+#if defined(IS_MUL_MM2)
+
+// For Q4_K and Q5_K in the mat-mul shader, we decode a tile's worth of scales
+// into shared memory and then process the whole tile using those scales.
+// There is a fetch function that loads into private variables and then a store
+// function that stores into shared memory.
+// Q4_K and Q5_K have the same encoding of scales, so everything is shared except
+// the part that fetches from the structure (which has a different block layout).
+#if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K)
+const uint shAscales_stride = (BM + 2);
+// 1 scale per 32 elements -> 8 scales per block, per row
+shared vec2 shAscales[8 * shAscales_stride];
+uvec4 row_v;
+#endif
+
+#if defined(DATA_A_Q4_K)
+layout (binding = 0) readonly buffer A_Q4_K_128 {block_q4_K_packed128 data_a_q4_k_packed128[];};
+
+void fetch_scalesQ4_K(uint ir_BM, uint pos_a, uint stride_a, uint block_k, uint tid, bool in_bounds)
+{
+    uint tids_per_row = BLOCK_SIZE / BM;
+    uint is_per_tid = 8 / tids_per_row;
+    uint is_start = is_per_tid * (tid % tids_per_row);
+    uint tid_row = tid / tids_per_row;
+
+    uint row = ir_BM + tid_row;
+    uint block_index = pos_a + row * stride_a + (block_k / QUANT_K);
+    if (in_bounds || row < p.M) {
+        row_v = data_a_q4_k_packed128[block_index].q4k[0];
+    }
+}
+#endif
+#if defined(DATA_A_Q5_K)
+layout (binding = 0) readonly buffer A_Q5_K_128 {block_q5_K_packed128 data_a_q5_k_packed128[];};
+
+void fetch_scalesQ5_K(uint ir_BM, uint pos_a, uint stride_a, uint block_k, uint tid, bool in_bounds)
+{
+    uint tids_per_row = BLOCK_SIZE / BM;
+    uint is_per_tid = 8 / tids_per_row;
+    uint is_start = is_per_tid * (tid % tids_per_row);
+    uint tid_row = tid / tids_per_row;
+
+    uint row = ir_BM + tid_row;
+    uint block_index = pos_a + row * stride_a + (block_k / QUANT_K);
+    if (in_bounds || row < p.M) {
+        row_v = data_a_q5_k_packed128[block_index].q5k[0];
+    }
+}
+#endif
+
+#if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K)
+void store_scalesQ4_K(uint tid)
+{
+    barrier();
+
+    uint tids_per_row = BLOCK_SIZE / BM;
+    uint is_per_tid = 8 / tids_per_row;
+    uint is_start = is_per_tid * (tid % tids_per_row);
+    uint tid_row = tid / tids_per_row;
+
+    [[unroll]] for (uint idx = 0; idx < is_per_tid; ++idx) {
+        uint is = idx + is_start;
+        uvec4 v = row_v;
+        const vec2 loadd = vec2(unpackFloat2x16(v.x));
+
+        uint32_t sc;
+        uint32_t mbyte;
+
+        uint32_t scale0 = v.y;
+        uint32_t scale4 = v.z;
+        uint32_t scale8 = v.w;
+
+        uint32_t sc_lo = scale0;
+        uint32_t mb_lo = scale4;
+        uint32_t sc_hi = (scale8 & 0x0F0F0F0F) | ((scale0 & 0xC0C0C0C0) >> 2);
+        uint32_t mb_hi = ((scale8 & 0xF0F0F0F0) >> 4) | ((scale4 & 0xC0C0C0C0) >> 2);
+
+        sc = is < 4 ? sc_lo : sc_hi;
+        mbyte = is < 4 ? mb_lo : mb_hi;
+        sc = sc >> (8 * (is & 3));
+        mbyte = mbyte >> (8 * (is & 3));
+        sc &= 0x3F;
+        mbyte &= 0x3F;
+
+        const float d = loadd.x * float(sc);
+        const float m = loadd.y * float(mbyte);
+        shAscales[is * shAscales_stride + tid_row] = vec2(d,m);
+    }
+
+    barrier();
+}
+#endif
+
+#endif
+
 float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
 {
     decodeBufQ4_K_packed16 bl16 = decodeBufQ4_K_packed16(bl);
@@ -176,8 +271,12 @@ float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2
     const uint b = (idx & 0x20) >> 5; // 0,1
     const uint is = (idx & 0xE0) >> 5; // 0..7
 
+#if defined(IS_MUL_MM2) && defined(DATA_A_Q4_K)
+    vec2 v = shAscales[is * shAscales_stride + (blockCoords[0] % BM)];
+    float d = v.x;
+    float m = v.y;
+#else
     uvec4 v = bl128.block.q4k[0];
-
     const vec2 loadd = vec2(unpackFloat2x16(v.x));
 
     uint32_t sc;
@@ -201,6 +300,7 @@ float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2
 
     const float d = loadd.x * float(sc);
     const float m = loadd.y * float(mbyte);
+#endif
 
     uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]);
     qs = (qs >> (b * 4 + 8 * (idx & 1))) & 0xF;
@@ -231,6 +331,11 @@ float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2
     const uint b = (idx & 0x20) >> 5; // 0,1
     const uint is = (idx & 0xE0) >> 5; // 0..7
 
+#if defined(IS_MUL_MM2) && defined(DATA_A_Q5_K)
+    vec2 v = shAscales[is * shAscales_stride + (blockCoords[0] % BM)];
+    float d = v.x;
+    float m = v.y;
+#else
     uvec4 v = bl128.block.q5k[0];
 
     const f16vec2 loadd = unpackFloat2x16(v.x);
@@ -256,6 +361,7 @@ float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2
 
     const float16_t d = loadd.x * float16_t(sc);
     const float16_t m = loadd.y * float16_t(mbyte);
+#endif
 
     uint qh = uint32_t(bl16.block.qh[(idx & 0x1E) >> 1]);
     qh = ((qh >> is) & 0x101) << 4;
@@ -264,9 +370,9 @@ float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2
     qs = (qs >> (b * 4)) & 0x0F0F;
     qs = unpack8(qs | qh)[idx & 1];
 
-    float16_t ret = d * (float16_t(qs)) - m;
+    float ret = d * float(qs) - m;
 
-    return ret;
+    return float16_t(ret);
 }
 
 layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ6_K {
@@ -564,8 +670,12 @@ float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoor
 #define dequantFuncA dequantFuncQ3_K
 #elif defined(DATA_A_Q4_K)
 #define dequantFuncA dequantFuncQ4_K
+#define fetch_scales fetch_scalesQ4_K
+#define store_scales store_scalesQ4_K
 #elif defined(DATA_A_Q5_K)
 #define dequantFuncA dequantFuncQ5_K
+#define fetch_scales fetch_scalesQ5_K
+#define store_scales store_scalesQ4_K
 #elif defined(DATA_A_Q6_K)
 #define dequantFuncA dequantFuncQ6_K
 #elif defined(DATA_A_IQ1_S)
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp
index 7649febb0..06b7ab09e 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp
@@ -19,6 +19,9 @@
 
 layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 
+#define IS_MUL_MM2 1
+
+layout (constant_id = 0) const uint BLOCK_SIZE = 256;
 layout (constant_id = 1) const uint BM = 64;
 layout (constant_id = 2) const uint BN = 64;
 layout (constant_id = 3) const uint BK = 16;  // Assumed to be 32 if working with a quant
@@ -70,6 +73,13 @@ layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
 #define DECODEFUNCA
 #endif
 
+#if !defined(fetch_scales)
+#define fetch_scales(a, b, c, d, e, f)
+#endif
+#if !defined(store_scales)
+#define store_scales(a)
+#endif
+
 #ifdef MUL_MAT_ID
 layout (binding = 3) readonly buffer IDS {int data_ids[];};
 
@@ -116,6 +126,8 @@ void main() {
     init_iq_shmem(gl_WorkGroupSize);
 #endif
 
+    const uint tid = gl_LocalInvocationIndex;
+
#ifdef MUL_MAT_ID
     const uint expert_idx = gl_GlobalInvocationID.z;
#else
@@ -218,14 +230,21 @@ void main() {
     tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0);
 
 #if !defined(MUL_MAT_ID)
+
+    const uint START_ALIGN_K = 256;
+    // For Qi_K (block size 256), unroll whole 256 element tiles.
+    // For legacy quants (block size 32), unroll 8x.
+    const uint UNROLL_K = (QUANT_K == 256) ? 256 : (BK * 8);
+    const uint unroll_count = UNROLL_K / BK;
+
     // Detect a fast path where all loads are entirely in bounds and no clamping is required
-    if ((ir + 1) * BM <= p.M && (ic + 1) * BN <= p.padded_N && (start_k % BK) == 0 && (end_k % BK) == 0 &&
+    if ((ir + 1) * BM <= p.M && (ic + 1) * BN <= p.padded_N && (start_k % START_ALIGN_K) == 0 && (end_k % BK) == 0 &&
 #if QUANT_K == 1
         (stride_a % 8) == 0 &&
 #endif
-        (stride_b % 8) == 0 && (start_k % 8) == 0) {
+        (stride_b % 8) == 0) {
         // Hint to the compiler that values are aligned (want 16B alignment)
-        start_k &= ~7;
+        start_k &= ~(START_ALIGN_K-1);
         stride_b &= ~7;
 #if QUANT_K == 1
         stride_a &= ~7;
@@ -234,11 +253,39 @@ void main() {
         tensorLayoutA = setTensorLayoutStrideNV(tensorLayoutA, stride_a, 1);
         tensorLayoutB = setTensorLayoutStrideNV(tensorLayoutB, stride_b, 1);
 
-        uint k_iters = (end_k - start_k + BK - 1) / BK;
+        uint k_iters = (end_k - start_k) / UNROLL_K;
+        uint block_k = start_k;
+
+        // fetch scale values for a tile of quants. These will be copied into shared memory.
+        // The fetches and stores are pipelined to hide the latency.
+        fetch_scales(ir * BM, pos_a, stride_a, start_k, tid, true);
+
         if (enable_smaller_matrices && ic * BN + BNover4 >= p.N) {
             coopmat sum = coopmat(0.0);
-            for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
+            for (uint i = 0; i < k_iters; ++i) {
+                store_scales(tid);
+                if (block_k + UNROLL_K < end_k) {
+                    fetch_scales(ir * BM, pos_a, stride_a, block_k + UNROLL_K, tid, true);
+                }
+
+                // Manually partial unroll
+                [[unroll]] for (uint j = 0; j < unroll_count; ++j) {
+                    coopmat mat_a;
+                    coopmat mat_b;
+
+                    coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
+                    coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose);
+
+                    sum = coopMatMulAdd(mat_a, mat_b, sum);
+                    block_k += BK;
+                }
+            }
+            // Do any remaining iterations that were not unrolled
+            if (block_k < end_k) {
+                store_scales(tid);
+            }
+            while (block_k < end_k) {
                 coopmat mat_a;
                 coopmat mat_b;
 
                 coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
@@ -246,6 +293,7 @@ void main() {
                 coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose);
 
                 sum = coopMatMulAdd(mat_a, mat_b, sum);
+                block_k += BK;
             }
 
             coopmat mat_d = coopmat(sum);
@@ -253,8 +301,30 @@ void main() {
             return;
         } else if (enable_smaller_matrices && ic * BN + BNover2 >= p.N) {
             coopmat sum = coopmat(0.0);
-            for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
+            for (uint i = 0; i < k_iters; ++i) {
+                store_scales(tid);
+                if (block_k + UNROLL_K < end_k) {
+                    fetch_scales(ir * BM, pos_a, stride_a, block_k + UNROLL_K, tid, true);
+                }
+
+                // Manually partial unroll
+                [[unroll]] for (uint j = 0; j < unroll_count; ++j) {
+                    coopmat mat_a;
+                    coopmat mat_b;
+
+                    coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
+                    coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose);
+
+                    sum = coopMatMulAdd(mat_a, mat_b, sum);
+                    block_k += BK;
+                }
+            }
+            // Do any remaining iterations that were not unrolled
+            if (block_k < end_k) {
+                store_scales(tid);
+            }
+            while (block_k < end_k) {
                 coopmat mat_a;
                 coopmat mat_b;
 
                 coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
@@ -262,6 +332,7 @@ void main() {
                 coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose);
 
                 sum = coopMatMulAdd(mat_a, mat_b, sum);
+                block_k += BK;
             }
 
             coopmat mat_d = coopmat(sum);
@@ -269,8 +340,31 @@ void main() {
             return;
         } else {
             coopmat sum = coopmat(0.0);
-            for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
+            for (uint i = 0; i < k_iters; ++i) {
+
+                store_scales(tid);
+                if (block_k + UNROLL_K < end_k) {
+                    fetch_scales(ir * BM, pos_a, stride_a, block_k + UNROLL_K, tid, true);
+                }
+
+                // Manually partial unroll
+                [[unroll]] for (uint j = 0; j < unroll_count; ++j) {
+                    coopmat mat_a;
+                    coopmat mat_b;
+
+                    coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
+                    coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose);
+
+                    sum = coopMatMulAdd(mat_a, mat_b, sum);
+                    block_k += BK;
+                }
+            }
+            // Do any remaining iterations that were not unrolled
+            if (block_k < end_k) {
+                store_scales(tid);
+            }
+            while (block_k < end_k) {
                 coopmat mat_a;
                 coopmat mat_b;
 
                 coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
@@ -278,6 +372,7 @@ void main() {
                 coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose);
 
                 sum = coopMatMulAdd(mat_a, mat_b, sum);
+                block_k += BK;
             }
 
             coopmat mat_d = coopmat(sum);
@@ -298,47 +393,29 @@ void main() {
     coopmat sum;
     sum = coopmat(0.0);
 
+    uint k_iters = (end_k - start_k + BK - 1) / BK;
+
+    fetch_scales(ir * BM, pos_a, stride_a, start_k, tid, false);
+
     [[dont_unroll]]
-    for (uint block_k = start_k; block_k < end_k; block_k += BK) {
+    for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
+
+        store_scales(tid);
+        if (block_k + BK < end_k) {
+            fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false);
+        }
 
         coopmat mat_a;
         coopmat mat_b;
 
-        // Clamping is expensive, so detect different code paths for each combination
-        // of A and B needing clamping.
-        bool unclampedA = (ir + 1) * BM <= p.M && block_k + BK <= end_k && (block_k % 8) == 0;
+        coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
 #ifdef MUL_MAT_ID
-        bool unclampedB = true;
+        coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
 #else
-        bool unclampedB = (ic + 1) * BN <= p.padded_N && block_k + BK <= end_k && (block_k % 8) == 0;
+        coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
 #endif
 
-        if (unclampedA && unclampedB) {
-            coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, (block_k & ~7), BK) DECODEFUNCA);
-#ifdef MUL_MAT_ID
-            coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
-#else
-            coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, (block_k & ~7), BK), tensorViewTranspose);
-#endif
-            sum = coopMatMulAdd(mat_a, mat_b, sum);
-        } else if (unclampedA && !unclampedB) {
-            coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, (block_k & ~7), BK) DECODEFUNCA);
-            coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
-            sum = coopMatMulAdd(mat_a, mat_b, sum);
-        } else if (!unclampedA && unclampedB) {
-            coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
-#ifdef MUL_MAT_ID
-            coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
-#else
-            coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, (block_k & ~7), BK), tensorViewTranspose);
-#endif
-            sum = coopMatMulAdd(mat_a, mat_b, sum);
-        } else if (!unclampedA && !unclampedB) {
-            coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
-            coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
-
-            sum = coopMatMulAdd(mat_a, mat_b, sum);
-        }
+        sum = coopMatMulAdd(mat_a, mat_b, sum);
     }
 
     // Convert from ACC_TYPE to D_TYPE

From 6e1c4cebdb697f925c523d3a969128d945161bdd Mon Sep 17 00:00:00 2001
From: Chenguang Li <757486878@qq.com>
Date: Wed, 9 Apr 2025 14:04:14 +0800
Subject: [PATCH 08/20] CANN: Support Opt CONV_TRANSPOSE_1D and ELU (#12786)

* [CANN] Support ELU and CONV_TRANSPOSE_1D

* [CANN]Modification review comments

* [CANN]Modification review comments

* [CANN]name adjustment

* [CANN]remove lambda used in template

* [CANN]Use std::func instead of template

* [CANN]Modify the code according to the review comments

---------

Signed-off-by: noemotiovon
---
 .devops/llama-cli-cann.Dockerfile |   4 +-
 .github/workflows/build.yml       |   4 +-
 ggml/src/ggml-cann/aclnn_ops.cpp  |  62 ++++++++++
 ggml/src/ggml-cann/aclnn_ops.h    | 189 +++++++++++++++++++-----------
 ggml/src/ggml-cann/ggml-cann.cpp  |  23 ++--
 5 files changed, 204 insertions(+), 78 deletions(-)

diff --git a/.devops/llama-cli-cann.Dockerfile b/.devops/llama-cli-cann.Dockerfile
index 02dce501c..0eb1af87c 100644
--- a/.devops/llama-cli-cann.Dockerfile
+++ b/.devops/llama-cli-cann.Dockerfile
@@ -1,4 +1,4 @@
-ARG ASCEND_VERSION=8.0.rc2.alpha003-910b-openeuler22.03-py3.8
+ARG ASCEND_VERSION=8.1.RC1.alpha001-910b-openeuler22.03-py3.10
 
 FROM ascendai/cann:$ASCEND_VERSION AS build
 
@@ -6,7 +6,7 @@ WORKDIR /app
 
 COPY . .
 
-RUN yum install -y gcc g++ cmake make
+RUN yum install -y gcc g++ cmake make libcurl-devel
 ENV ASCEND_TOOLKIT_HOME=/usr/local/Ascend/ascend-toolkit/latest
 ENV LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/lib64:$LIBRARY_PATH
 ENV LD_LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/lib64:${ASCEND_TOOLKIT_HOME}/lib64/plugin/opskernel:${ASCEND_TOOLKIT_HOME}/lib64/plugin/nnengine:${ASCEND_TOOLKIT_HOME}/opp/built-in/op_impl/ai_core/tbe/op_tiling:${LD_LIBRARY_PATH}
diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index 33f6a4fb4..bcfcf08ac 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -1771,7 +1771,7 @@ jobs:
     strategy:
       matrix:
         cann:
-          - '8.0.rc3.beta1-910b-openeuler22.03-py3.10'
+          - '8.1.RC1.alpha001-910b-openeuler22.03-py3.10'
         device:
          - 'ascend910b3'
        build:
@@ -1784,7 +1784,7 @@ jobs:
       - name: Dependencies
         run: |
           yum update -y
-          yum install -y git gcc gcc-c++ make cmake
+          yum install -y git gcc gcc-c++ make cmake libcurl-devel
 
       - name: Build
         run: |
diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp
index fee66aea9..25b2599c7 100644
--- a/ggml/src/ggml-cann/aclnn_ops.cpp
+++ b/ggml/src/ggml-cann/aclnn_ops.cpp
@@ -57,6 +57,8 @@
 #include
 #include
 #include
+#include
+#include
 #include
 #include
 
@@ -86,6 +88,20 @@
     }
 }
 
+void ggml_cann_unary_op(
+    std::function<void(ggml_backend_cann_context&, aclTensor*, aclTensor*)> unary_op,
+    ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+    ggml_tensor* src = dst->src[0];
+
+    aclTensor* acl_src = ggml_cann_create_tensor(src);
+    aclTensor* acl_dst = ggml_cann_create_tensor(dst);
+
+    unary_op(ctx, acl_src, acl_dst);
+
+    ACL_CHECK(aclDestroyTensor(acl_src));
+    ACL_CHECK(aclDestroyTensor(acl_dst));
+}
+
 /**
  * @brief Repeats elements of a tensor along each dimension according to the
  * specified repeat array.
@@ -2585,3 +2601,49 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     ACL_CHECK(aclDestroyTensor(acl_src));
     ACL_CHECK(aclDestroyTensor(acl_dst));
 }
+
+void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* dst){
+    ggml_tensor * src0 = dst->src[0];
+    ggml_tensor * src1 = dst->src[1];
+
+    // stride
+    int64_t s0 = ((const int32_t*)(dst->op_params))[0];
+
+    aclTensor* acl_input = ggml_cann_create_tensor(src1, src1->ne, src1->nb, 3, ACL_FORMAT_NCL);
+    aclTensor* acl_weight = ggml_cann_create_tensor(src0, src0->ne, src0->nb, 3, ACL_FORMAT_NCL);
+    aclTensor* acl_dst = ggml_cann_create_tensor(dst, dst->ne, dst->nb, 3, ACL_FORMAT_NCL);
+
+    int64_t strideVal[1];
+    strideVal[0] = s0;
+    aclIntArray *stride = aclCreateIntArray(strideVal, 1);
+    int64_t paddingVal[] = {0};
+    aclIntArray *padding = aclCreateIntArray(paddingVal, 1);
+    int64_t dilationVal[] = {1};
+    aclIntArray *dilation = aclCreateIntArray(dilationVal, 1);
+    bool transposed = true;
+    int64_t groups = 1;
+    int8_t cubeMathType = 0;
+
+    GGML_CANN_CALL_ACLNN_OP(Convolution, acl_input, acl_weight, nullptr, stride,
+        padding, dilation, transposed, padding, groups, acl_dst, cubeMathType);
+
+    ACL_CHECK(aclDestroyTensor(acl_weight));
+    ACL_CHECK(aclDestroyTensor(acl_dst));
+}
+
+void ggml_cann_elu(ggml_backend_cann_context& ctx, ggml_tensor* dst){
+    ggml_tensor * src0 = dst->src[0];
+
+    aclTensor* acl_input = ggml_cann_create_tensor(src0);
+    aclTensor* acl_dst = ggml_cann_create_tensor(dst);
+
+    float alphaValue = 1.0f;
+    aclScalar* alpha = nullptr;
+    alpha = aclCreateScalar(&alphaValue, aclDataType::ACL_FLOAT);
+
+    GGML_CANN_CALL_ACLNN_OP(Elu, acl_input, alpha, alpha, alpha,
+        acl_dst);
+
+    ACL_CHECK(aclDestroyTensor(acl_input));
+    ACL_CHECK(aclDestroyTensor(acl_dst));
+}
diff --git a/ggml/src/ggml-cann/aclnn_ops.h b/ggml/src/ggml-cann/aclnn_ops.h
index 116ddf0fb..aadf013de 100644
--- a/ggml/src/ggml-cann/aclnn_ops.h
+++ b/ggml/src/ggml-cann/aclnn_ops.h
@@ -1,15 +1,4 @@
-#ifndef CANN_ACLNN_OPS
-#define CANN_ACLNN_OPS
-
 /**
- * @file acl_tensor
- * @brief This file contains related functions of ggml_tensor and acl_tensor.
- * Contains conversion from ggml_tensor to acl_tensor, broadcast and other
- * functions.
- * @author hipudding
- * @author wangshuai09 <391746016@qq.com>
- * @date July 15, 2024
- *
  * Copyright (c) 2023-2024 The ggml authors
@@ -31,6 +20,9 @@
  * IN THE SOFTWARE.
  */
 
+#ifndef CANN_ACLNN_OPS
+#define CANN_ACLNN_OPS
+
 #include
 #include
 #include
@@ -483,8 +475,8 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst);
  * operation is executed using the CANN backend for optimized performance.
  *
  * @param ctx The CANN context used for operations.
- * @param dst The destination tensor where the indices of the maximum values will be stored.
- * dst->op is `GGML_OP_ARGMAX`.
+ * @param dst The destination tensor where the indices of the maximum values will
+ * be stored. dst->op is `GGML_OP_ARGMAX`.
 */
 void ggml_cann_argmax(ggml_backend_cann_context& ctx, ggml_tensor* dst);
@@ -599,6 +591,99 @@ void aclnn_cos(ggml_backend_cann_context& ctx, aclTensor* acl_src,
 void aclnn_sin(ggml_backend_cann_context& ctx, aclTensor* acl_src,
                aclTensor* acl_dst);
 
+/**
+ * @brief Prepares broadcast-compatible ACL tensors for two input tensors and one
+ * output tensor.
+ *
+ * This function checks whether broadcasting is needed between `src0` and `src1`.
+ * If broadcasting is required, it calculates the proper shapes and creates
+ * ACL tensors with broadcast parameters. Otherwise, it directly creates ACL tensors
+ * based on the original tensor shapes.
+ *
+ * @param src0 The first input tensor (reference shape).
+ * @param src1 The second input tensor (possibly broadcasted).
+ * @param dst The destination/output tensor.
+ * @param acl_src0 Output pointer to the created ACL tensor corresponding to src0.
+ * @param acl_src1 Output pointer to the created ACL tensor corresponding to src1.
+ * @param acl_dst Output pointer to the created ACL tensor corresponding to dst.
+ */
+void bcast_shape(ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst,
+                 aclTensor ** acl_src0, aclTensor ** acl_src1, aclTensor ** acl_dst);
+
+/**
+ * @brief Computes the 1D transposed convolution (deconvolution) of a ggml
+ * tensor using the CANN backend.
+ *
+ * @details This function performs a 1D transposed convolution (also known as
+ * deconvolution) operation on the input tensor. The computed result is stored
+ * in the destination tensor `dst`. The operation is optimized using the CANN
+ * backend for improved performance.
+ *
+ * @param ctx The CANN context used for operations.
+ * @param dst The destination tensor where the transposed convolution result
+ * will be stored. dst->op is `GGML_OP_CONV_TRANSPOSE_1D`.
+ */
+void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* dst);
+
+/**
+ * @brief Applies the ELU (Exponential Linear Unit) activation to a ggml tensor
+ * using the CANN backend.
+ *
+ * @details This function performs an element-wise ELU activation on the input
+ * tensor.
+ * The result is written to the destination tensor `dst` in-place.
+ * The ELU function is defined as:
+ *
+ * \text{ELU}(x) =
+ * \begin{cases}
+ * x, & \text{if } x > 0 \\
+ * \alpha \left( \exp(x) - 1 \right), & \text{if } x \leq 0
+ * \end{cases}
+ *
+ * where α (alpha) is a hyperparameter, typically set to 1.0.
+ * This operation is optimized using the CANN backend for high-performance
+ * inference or training.
+ *
+ * @param ctx The CANN context used for operations.
+ * @param dst The destination tensor where the ELU-activated result will be stored.
+ * dst->op is expected to be `GGML_OP_ELU`.
+ */
+void ggml_cann_elu(ggml_backend_cann_context& ctx, ggml_tensor* dst);
+
+/**
+ * @brief Applies an element-wise operation to two input tensors using the CANN
+ * backend.
+ *
+ * This templated function takes a binary operator and applies it to two source
+ * tensors
+ * associated with the destination tensor. The function handles broadcasting as
+ * needed.
+ *
+ * @tparam binary_op A callable object (e.g., lambda or function pointer) representing
+ * the binary operation to be performed. It must take three arguments:
+ * (ggml_backend_cann_context&, aclTensor*, aclTensor*, aclTensor*).
+ *
+ * @param ctx The CANN backend context used to manage execution and resources.
+ * @param dst The destination tensor.
+ */
+template
+void ggml_cann_binary_op(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+    ggml_tensor* src0 = dst->src[0];
+    ggml_tensor* src1 = dst->src[1];
+
+    aclTensor* acl_src0;
+    aclTensor* acl_src1;
+    aclTensor* acl_dst;
+
+    // Need bcast
+    bcast_shape(src0, src1, dst, &acl_src0, &acl_src1, &acl_dst);
+    binary_op(ctx, acl_src0, acl_src1, acl_dst);
+
+    ACL_CHECK(aclDestroyTensor(acl_src0));
+    ACL_CHECK(aclDestroyTensor(acl_src1));
+    ACL_CHECK(aclDestroyTensor(acl_dst));
+}
+
 /**
  * @brief Launches an asynchronous task using the memory allocator.
  *
@@ -631,56 +716,6 @@ void aclnn_sin(ggml_backend_cann_context& ctx, aclTensor* acl_src,
     ACL_CHECK(aclnn##OP_NAME(workspaceAddr, workspaceSize, executor, ctx.stream())); \
     } while (0)
 
-
-/**
- * @brief Prepares broadcast-compatible ACL tensors for two input tensors and one output tensor.
- *
- * This function checks whether broadcasting is needed between `src0` and `src1`.
- * If broadcasting is required, it calculates the proper shapes and creates
- * ACL tensors with broadcast parameters. Otherwise, it directly creates ACL tensors
- * based on the original tensor shapes.
- *
- * @param src0 The first input tensor (reference shape).
- * @param src1 The second input tensor (possibly broadcasted).
- * @param dst The destination/output tensor.
- * @param acl_src0 Output pointer to the created ACL tensor corresponding to src0.
- * @param acl_src1 Output pointer to the created ACL tensor corresponding to src1.
- * @param acl_dst Output pointer to the created ACL tensor corresponding to dst.
- */
-void bcast_shape(ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst, aclTensor ** acl_src0,
-                 aclTensor ** acl_src1, aclTensor ** acl_dst);
-
-/**
- * @brief Applies a element-wise operation to two input tensors using the CANN backend.
- *
- * This templated function takes a binary operator and applies it to two source tensors
- * associated with the destination tensor. The function handles broadcasting as needed.
- *
- * @tparam binary_op A callable object (e.g., lambda or function pointer) representing
- * the binary operation to be performed. It must take three arguments:
- * (ggml_backend_cann_context&, aclTensor*, aclTensor*, aclTensor*).
- *
- * @param ctx The CANN backend context used to manage execution and resources.
- * @param dst The destination tensor.
- */
-template
-void ggml_cann_binary_op(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
-    ggml_tensor* src0 = dst->src[0];
-    ggml_tensor* src1 = dst->src[1];
-
-    aclTensor* acl_src0;
-    aclTensor* acl_src1;
-    aclTensor* acl_dst;
-
-    // Need bcast
-    bcast_shape(src0, src1, dst, &acl_src0, &acl_src1, &acl_dst);
-    binary_op(ctx, acl_src0, acl_src1, acl_dst);
-
-    ACL_CHECK(aclDestroyTensor(acl_src0));
-    ACL_CHECK(aclDestroyTensor(acl_src1));
-    ACL_CHECK(aclDestroyTensor(acl_dst));
-}
-
 /**
  * @brief Applies a unary operation to an input tensor using the CANN backend.
  *
@@ -690,7 +725,6 @@
  * @tparam unary_op A callable with the signature:
  *         void(ggml_backend_cann_context&, aclTensor*, aclTensor*)
  *         where the first aclTensor is the source and the second is the destination.
- *
 * @param ctx The CANN backend context for managing resources and execution.
 * @param dst The destination tensor. Its src[0] is treated as the input tensor.
 */
@@ -702,10 +736,30 @@ template
     aclTensor* acl_dst = ggml_cann_create_tensor(dst);
 
     unary_op(ctx, acl_src, acl_dst);
+
     ACL_CHECK(aclDestroyTensor(acl_src));
     ACL_CHECK(aclDestroyTensor(acl_dst));
 }
 
+/**
+ * @brief Applies a unary operation to a ggml tensor using the CANN backend.
+ *
+ * @details This function performs a unary operation on the input tensor using
+ * a user-provided lambda or callable object `unary_op`, which accepts the CANN
+ * context and two ACL tensors (source and destination). Internally, this function
+ * creates ACL representations of the ggml tensors and invokes the unary operation.
+ * The result is stored in the destination tensor `dst`. This utility abstracts the
+ * common boilerplate of tensor conversion and cleanup when implementing unary ops.
+ *
+ * @param unary_op A callable that performs the unary operation using CANN APIs.
+ * @param ctx The CANN context used for operations.
+ * @param dst The destination tensor where the result will be stored.
+ * The source tensor is retrieved from `dst->src[0]`.
+ */
+void ggml_cann_unary_op(
+    std::function<void(ggml_backend_cann_context&, aclTensor*, aclTensor*)> unary_op,
+    ggml_backend_cann_context& ctx, ggml_tensor* dst);
+
 /**
  * @brief Helper macro to invoke a unary ACL operation using ggml_cann_unary_op.
  *
@@ -725,11 +779,12 @@
 */
 #define GGML_CANN_CALL_UNARY_OP(OP_NAME) \
     do { \
-        auto lambda = [](auto ctx, auto acl_src, auto acl_dst) { \
+        auto lambda = [](ggml_backend_cann_context& ctx, \
+            aclTensor* acl_src, \
+            aclTensor* acl_dst) { \
             GGML_CANN_CALL_ACLNN_OP(OP_NAME, acl_src, acl_dst); \
         }; \
-        ggml_cann_unary_op(ctx, dst); \
+        ggml_cann_unary_op(lambda, ctx, dst); \
     } \
     while (0)
-
 #endif  // CANN_ACLNN_OPS
diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp
index 326f9d298..f9187ba81 100644
--- a/ggml/src/ggml-cann/ggml-cann.cpp
+++ b/ggml/src/ggml-cann/ggml-cann.cpp
@@ -1330,12 +1330,13 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
                 GGML_CANN_CALL_UNARY_OP(Silu);
                 break;
             case GGML_UNARY_OP_GELU_QUICK: {
-                auto lambda = [](auto ctx, auto acl_src, auto acl_dst) {
-                    GGML_CANN_CALL_ACLNN_OP(GeluV2, acl_src, 0, acl_dst);
-                };
-                ggml_cann_unary_op(ctx, dst);
-            }
-            break;
+                auto lambda = [](ggml_backend_cann_context& ctx,
+                    aclTensor* acl_src,
+                    aclTensor* acl_dst) {
+                    GGML_CANN_CALL_ACLNN_OP(GeluV2, acl_src, 0, acl_dst);
+                };
+                ggml_cann_unary_op(lambda, ctx, dst);
+            } break;
             case GGML_UNARY_OP_TANH:
                 GGML_CANN_CALL_UNARY_OP(Tanh);
                 break;
@@ -1354,6 +1355,9 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
             case GGML_UNARY_OP_EXP:
                 GGML_CANN_CALL_UNARY_OP(Exp);
                 break;
+            case GGML_UNARY_OP_ELU:
+                ggml_cann_elu(ctx, dst);
+                break;
             default:
                 return false;
             }
@@ -1448,7 +1452,10 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
             break;
         case GGML_OP_SIN:
             ggml_cann_unary_op(ctx, dst);
-            break;
+            break;
+        case GGML_OP_CONV_TRANSPOSE_1D:
+            ggml_cann_conv_transpose_1d(ctx, dst);
+            break;
         default:
             return false;
     }
@@ -1710,6 +1717,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
             case GGML_UNARY_OP_GELU_QUICK:
             case GGML_UNARY_OP_TANH:
             case GGML_UNARY_OP_EXP:
+            case GGML_UNARY_OP_ELU:
                 return true;
             default:
                 return false;
@@ -1842,6 +1850,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
         case GGML_OP_ARGMAX:
         case GGML_OP_COS:
         case GGML_OP_SIN:
+        case GGML_OP_CONV_TRANSPOSE_1D:
            return true;
        default:
            return false;

From 47277d6d1d0d515cff34292a1a78a0d1b7252350 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Wed, 9 Apr 2025 10:54:42 +0300
Subject: [PATCH 09/20] readme : add rpc backend (#12842)

---
 README.md | 8 +------­-
 1 file changed, 1 insertion(+), 7 deletions(-)

diff --git a/README.md b/README.md
index 95a05e6ed..e56042f1c 100644
--- a/README.md
+++ b/README.md
@@ -9,13 +9,6 @@
 
 Inference of Meta's [LLaMA](https://arxiv.org/abs/2302.13971) model (and others) in pure C/C++
 
-> [!IMPORTANT]
-> New `llama.cpp` package location: [ggml-org/llama.cpp](https://github.com/ggml-org/llama.cpp/pkgs/container/llama.cpp)
->
-> Update your container URLs to: `ghcr.io/ggml-org/llama.cpp`
->
-> More info: https://github.com/ggml-org/llama.cpp/discussions/11801
-
 ## Recent API changes
 
 - [Changelog for `libllama` API](https://github.com/ggml-org/llama.cpp/issues/9289)
@@ -247,6 +240,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
 | [Vulkan](docs/build.md#vulkan) | GPU |
 | [CANN](docs/build.md#cann) | Ascend NPU |
 | [OpenCL](docs/backend/OPENCL.md) | Adreno GPU |
+| [RPC](https://github.com/ggml-org/llama.cpp/tree/master/examples/rpc) | All |
 
 ## Building the project
 

From 65a69e6e1b6d55cd5f78f8bcdfaba8a8c59a8d96 Mon Sep 17 00:00:00 2001
From: Xuan-Son Nguyen
Date: Wed, 9 Apr 2025 10:09:53 +0200
Subject: [PATCH 10/20] clip : do not print ftype (#12832)

---
 examples/llava/clip.cpp | 5 -----
 1 file changed, 5 deletions(-)

diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp
index 44428cc95..4f21e836a 100644
--- a/examples/llava/clip.cpp
+++ b/examples/llava/clip.cpp
@@ -331,7 +331,6 @@ struct clip_ctx {
     float image_std[3];
     bool use_gelu = false;
     bool use_silu = false;
-    int32_t ftype = 1;
 
     struct gguf_context * ctx_gguf = nullptr;
     struct ggml_context * ctx_data = nullptr;
@@ -1142,9 +1141,6 @@ struct clip_model_loader {
 
         // print gguf info
         {
-            int ftype = -1;
-            get_u32(KEY_FTYPE, ftype, false);
-            const std::string ftype_str = ggml_type_name(static_cast<ggml_type>(ftype));
             std::string name;
             get_string(KEY_NAME, name, false);
             std::string description;
@@ -1155,7 +1151,6 @@ struct clip_model_loader {
             LOG_INF("%s: alignment: %zu\n", __func__, gguf_get_alignment(ctx_gguf.get()));
             LOG_INF("%s: n_tensors: %d\n", __func__, n_tensors);
             LOG_INF("%s: n_kv: %d\n", __func__, (int)gguf_get_n_kv(ctx_gguf.get()));
-            LOG_INF("%s: ftype: %s\n", __func__, ftype_str.c_str());
             LOG_INF("\n");
         }
 

From 381603a77504ad1788965f694094540c1bed9ea2 Mon Sep 17 00:00:00 2001
From: Plamen Minev
Date: Wed, 9 Apr 2025 11:11:11 +0300
Subject: [PATCH 11/20] ci: detach common from the library (#12827)

* fix: detach common from the library

* fix: building chat test template
---
 examples/server/utils.hpp    |  2 +-
 src/CMakeLists.txt           |  2 +-
 tests/test-chat-template.cpp | 14 +++++++++-----
 3 files changed, 11 insertions(+), 7 deletions(-)

diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
index 55cf3230d..aba2f27f9 100644
--- a/examples/server/utils.hpp
+++ b/examples/server/utils.hpp
@@ -3,7 +3,7 @@
 #include "common.h"
 #include "log.h"
 #include "llama.h"
-#include "common/base64.hpp"
+#include "base64.hpp"
 
 // increase max payload length to allow use of larger context size
 #define CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 1048576
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index b340dae5b..9f7ab13f1 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -32,7 +32,7 @@ add_library(llama
             unicode.h
             )
 
-target_include_directories(llama PUBLIC . ../include ../common)
+target_include_directories(llama PUBLIC . ../include)
 target_compile_features   (llama PUBLIC cxx_std_17) # don't bump
 
 target_link_libraries(llama PUBLIC ggml)
diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp
index a9627df68..be1a64006 100644
--- a/tests/test-chat-template.cpp
+++ b/tests/test-chat-template.cpp
@@ -19,6 +19,8 @@ static std::string normalize_newlines(const std::string & s) {
 #endif
 }
 
+#define U8C(x) (const char*)(u8##x)
+
 static common_chat_msg simple_msg(const std::string & role, const std::string & content) {
     common_chat_msg msg;
     msg.role = role;
@@ -35,6 +37,8 @@ int main(void) {
         {"assistant", " I am an assistant "},
         {"user", "Another question"},
     };
+
+    // std::string wrong = /* .template_str= */ u8"[gMASK]{% for item in messages %}{% if item['tools'] is defined %}<|system|>\n你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,你的任务是针对用户的问题和要求提供适当的答复和支持。\n\n# 可用工具{% set tools = item['tools'] %}{% for tool in tools %}{% if tool['type'] == 'function' %}\n\n## {{ tool['function']['name'] }}\n\n{{ tool['function'] | tojson(indent=4) }}\n......{% endif %}{% endfor %}{% endif %}{% if item['content'] %}<|{{ item['role'] }}|>{{ item['metadata'] }}\n{{ item['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}";
     struct TestCase {
         std::string name;
         std::string template_str;
@@ -177,7 +181,7 @@ int main(void) {
     },
     {
         /* .name= */ "ChatGLM4",
-        /* .template_str= */ u8"[gMASK]{% for item in messages %}{% if item['tools'] is defined %}<|system|>\n你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,你的任务是针对用户的问题和要求提供适当的答复和支持。\n\n# 可用工具{% set tools = item['tools'] %}{% for tool in tools %}{% if tool['type'] == 'function' %}\n\n## {{ tool['function']['name'] }}\n\n{{ tool['function'] | tojson(indent=4) }}\n......{% endif %}{% endfor %}{% endif %}{% if item['content'] %}<|{{ item['role'] }}|>{{ item['metadata'] }}\n{{ item['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}",
+        /* .template_str= */ U8C("[gMASK]{% for item in messages %}{% if item['tools'] is defined %}<|system|>\n你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,你的任务是针对用户的问题和要求提供适当的答复和支持。\n\n# 可用工具{% set tools = item['tools'] %}{% for tool in tools %}{% if tool['type'] == 'function' %}\n\n## {{ tool['function']['name'] }}\n\n{{ tool['function'] | tojson(indent=4) }}\n......{% endif %}{% endfor %}{% endif %}{% if item['content'] %}<|{{ item['role'] }}|>{{ item['metadata'] }}\n{{ item['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}"),
         /* .expected_output= */ "[gMASK]<|system|>\nYou are a helpful assistant<|user|>\nHello<|assistant|>\nHi there<|user|>\nWho are you<|assistant|>\n I am an assistant <|user|>\nAnother question<|assistant|>",
         /* .expected_output_jinja= */ "",
         /* .bos_token= */ "",
@@ -193,8 +197,8 @@ int main(void) {
     },
     {
         /* .name= */ "MiniCPM-3B-OpenHermes-2.5-v2-GGUF",
-        /* .template_str= */ u8"{% for message in messages %}{% if message['role'] == 'user' %}{{'<用户>' + message['content'].strip() + ''}}{% else %}{{message['content'].strip()}}{% endif %}{% endfor %}",
-        /* .expected_output= */ u8"You are a helpful assistant<用户>HelloHi there<用户>Who are youI am an assistant<用户>Another question",
+        /* .template_str= */ U8C("{% for message in messages %}{% if message['role'] == 'user' %}{{'<用户>' + message['content'].strip() + ''}}{% else %}{{message['content'].strip()}}{% endif %}{% endfor %}"),
+        /* .expected_output= */ U8C("You are a helpful assistant<用户>HelloHi there<用户>Who are youI am an assistant<用户>Another question"),
         /* .expected_output_jinja= */ "",
         /* .bos_token= */ "",
         /* .eos_token= */ "",
     },
     {
        /* .name= */ "DeepSeek-V2",
        /* .template_str= */ "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ 'User: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}",
-       /* .expected_output= */ u8"You are a helpful assistant\n\nUser: Hello\n\nAssistant: Hi there<|end▁of▁sentence|>User: Who are you\n\nAssistant: I am an assistant <|end▁of▁sentence|>User: Another question\n\nAssistant:",
+       /* .expected_output= */ U8C("You are a helpful assistant\n\nUser: Hello\n\nAssistant: Hi there<|end▁of▁sentence|>User: Who are you\n\nAssistant: I am an assistant <|end▁of▁sentence|>User: Another question\n\nAssistant:"),
        /* .expected_output_jinja= */ "",
        /* .bos_token= */ "",
        /* .eos_token= */ "<|end▁of▁sentence|>",
@@ -256,7 +260,7 @@ int main(void) {
     },
     {
        /* .name= */ "Infinigence/Megrez-3B-Instruct",
-       /* .template_str= */ u8"{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|role_start|>system<|role_end|>你是Megrez-3B-Instruct,将针对用户的问题给出详细的、积极的回答。<|turn_end|>' }}{% endif %}{{ '<|role_start|>' + message['role'] + '<|role_end|>' + message['content'] + '<|turn_end|>' }}{% endfor %}{% if add_generation_prompt %}{{ '<|role_start|>assistant<|role_end|>' }}{% endif %}",
+       /* .template_str= */ U8C("{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|role_start|>system<|role_end|>你是Megrez-3B-Instruct,将针对用户的问题给出详细的、积极的回答。<|turn_end|>' }}{% endif %}{{ '<|role_start|>' + message['role'] + '<|role_end|>' + message['content'] + '<|turn_end|>' }}{% endfor %}{% if add_generation_prompt %}{{ '<|role_start|>assistant<|role_end|>' }}{% endif %}"),
        /* .expected_output= */ "<|role_start|>system<|role_end|>You are a helpful assistant<|turn_end|><|role_start|>user<|role_end|>Hello<|turn_end|><|role_start|>assistant<|role_end|>Hi there<|turn_end|><|role_start|>user<|role_end|>Who are you<|turn_end|><|role_start|>assistant<|role_end|> I am an assistant <|turn_end|><|role_start|>user<|role_end|>Another question<|turn_end|><|role_start|>assistant<|role_end|>",
        /* .expected_output_jinja= */ "",
        /* .bos_token= */ "",

From 8ed71242f464dc0a3fb3cffcfe064e55bdec72f9 Mon Sep 17 00:00:00 2001
From: Romain Biessy
Date: Wed, 9 Apr 2025 11:22:04 +0200
Subject: [PATCH 12/20] sycl: update documentation to use -no-cnv (#12845)

---
 docs/backend/SYCL.md | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/docs/backend/SYCL.md b/docs/backend/SYCL.md
index cb29075b1..20aefec2f 100644
--- a/docs/backend/SYCL.md
+++ b/docs/backend/SYCL.md
@@ -425,13 +425,13 @@ Examples:
 - Use device 0:
 
 ```sh
-ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m models/llama-2-7b.Q4_0.gguf -p "Building a website can be done in 10 simple steps:" -n 400 -e -ngl 33 -sm none -mg 0
+ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -no-cnv -m models/llama-2-7b.Q4_0.gguf -p "Building a website can be done in 10 simple steps:" -n 400 -e -ngl 33 -sm none -mg 0
 ```
 
 - Use multiple devices:
 
 ```sh
-ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m models/llama-2-7b.Q4_0.gguf -p "Building a website can be done in 10 simple steps:" -n 400 -e -ngl 33
-sm layer +ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -no-cnv -m models/llama-2-7b.Q4_0.gguf -p "Building a website can be done in 10 simple steps:" -n 400 -e -ngl 33 -sm layer ``` *Notes:* @@ -697,13 +697,13 @@ Examples: - Use device 0: ``` -build\bin\llama-cli.exe -m models\llama-2-7b.Q4_0.gguf -p "Building a website can be done in 10 simple steps:\nStep 1:" -n 400 -e -ngl 33 -s 0 -sm none -mg 0 +build\bin\llama-cli.exe -no-cnv -m models\llama-2-7b.Q4_0.gguf -p "Building a website can be done in 10 simple steps:\nStep 1:" -n 400 -e -ngl 33 -s 0 -sm none -mg 0 ``` - Use multiple devices: ``` -build\bin\llama-cli.exe -m models\llama-2-7b.Q4_0.gguf -p "Building a website can be done in 10 simple steps:\nStep 1:" -n 400 -e -ngl 33 -s 0 -sm layer +build\bin\llama-cli.exe -no-cnv -m models\llama-2-7b.Q4_0.gguf -p "Building a website can be done in 10 simple steps:\nStep 1:" -n 400 -e -ngl 33 -s 0 -sm layer ``` From d9a63b2f2e91cdcb0eda211b7f49fadcdea0f664 Mon Sep 17 00:00:00 2001 From: R0CKSTAR Date: Wed, 9 Apr 2025 17:22:30 +0800 Subject: [PATCH 13/20] musa: enable freediskspace for docker image build (#12839) Signed-off-by: Xiaodong Ye --- .github/workflows/docker.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index c81d21fcd..9eba3f6a4 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -38,7 +38,7 @@ jobs: # Multi-stage build - { tag: "cpu", dockerfile: ".devops/cpu.Dockerfile", platforms: "linux/amd64,linux/arm64", full: true, light: true, server: true, freediskspace: false} - { tag: "cuda", dockerfile: ".devops/cuda.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, freediskspace: false} - - { tag: "musa", dockerfile: ".devops/musa.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, freediskspace: false} + - { tag: "musa", dockerfile: ".devops/musa.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, freediskspace: true} - { tag: "intel", dockerfile: ".devops/intel.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, freediskspace: false} - { tag: "vulkan", dockerfile: ".devops/vulkan.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, freediskspace: false} # Note: the rocm images are failing due to a compiler error and are disabled until this is fixed to allow the workflow to complete From d3bd7193ba66c15963fd1c59448f22019a8caf6e Mon Sep 17 00:00:00 2001 From: Bo Zheng <368586905@qq.com> Date: Wed, 9 Apr 2025 17:47:36 +0800 Subject: [PATCH 14/20] llama : Support Qwen3 and Qwen3MoE (#12828) * add qwen3 & qwen3moe support. 
* fix --------- Co-authored-by: bozheng-hit --- convert_hf_to_gguf.py | 10 ++ gguf-py/gguf/constants.py | 38 +++++ src/llama-arch.cpp | 41 +++++ src/llama-arch.h | 2 + src/llama-model.cpp | 350 ++++++++++++++++++++++++++++++++++++++ 5 files changed, 441 insertions(+) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 954990020..656dc9877 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -2459,6 +2459,16 @@ class Qwen2MoeModel(Model): raise ValueError(f"Unprocessed experts: {experts}") +@Model.register("Qwen3ForCausalLM") +class Qwen3Model(Qwen2Model): + model_arch = gguf.MODEL_ARCH.QWEN3 + + +@Model.register("Qwen3MoeForCausalLM") +class Qwen3MoeModel(Qwen2MoeModel): + model_arch = gguf.MODEL_ARCH.QWEN3MOE + + @Model.register("GPT2LMHeadModel") class GPT2Model(Model): model_arch = gguf.MODEL_ARCH.GPT2 diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index d4f4e1179..0410654dd 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -248,6 +248,8 @@ class MODEL_ARCH(IntEnum): QWEN2 = auto() QWEN2MOE = auto() QWEN2VL = auto() + QWEN3 = auto() + QWEN3MOE = auto() PHI2 = auto() PHI3 = auto() PHIMOE = auto() @@ -453,6 +455,8 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.QWEN2: "qwen2", MODEL_ARCH.QWEN2MOE: "qwen2moe", MODEL_ARCH.QWEN2VL: "qwen2vl", + MODEL_ARCH.QWEN3: "qwen3", + MODEL_ARCH.QWEN3MOE: "qwen3moe", MODEL_ARCH.PHI2: "phi2", MODEL_ARCH.PHI3: "phi3", MODEL_ARCH.PHIMOE: "phimoe", @@ -953,6 +957,40 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_DOWN_SHEXP, MODEL_TENSOR.FFN_UP_SHEXP, ], + MODEL_ARCH.QWEN3: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], + MODEL_ARCH.QWEN3MOE: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + ], MODEL_ARCH.PLAMO: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index ac997b963..264f8c5b9 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -26,6 +26,8 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = { { LLM_ARCH_QWEN2, "qwen2" }, { LLM_ARCH_QWEN2MOE, "qwen2moe" }, { LLM_ARCH_QWEN2VL, "qwen2vl" }, + { LLM_ARCH_QWEN3, "qwen3" }, + { LLM_ARCH_QWEN3MOE, "qwen3moe" }, { LLM_ARCH_PHI2, "phi2" }, { LLM_ARCH_PHI3, "phi3" }, { LLM_ARCH_PHIMOE, "phimoe" }, @@ -595,6 +597,45 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, }, }, + { + LLM_ARCH_QWEN3, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" 
}, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_QWEN3MOE, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + }, + }, { LLM_ARCH_PHI2, { diff --git a/src/llama-arch.h b/src/llama-arch.h index 42e4a3ef9..201935281 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -30,6 +30,8 @@ enum llm_arch { LLM_ARCH_QWEN2, LLM_ARCH_QWEN2MOE, LLM_ARCH_QWEN2VL, + LLM_ARCH_QWEN3, + LLM_ARCH_QWEN3MOE, LLM_ARCH_PHI2, LLM_ARCH_PHI3, LLM_ARCH_PHIMOE, diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 4546e9cf9..9e4166a71 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -787,6 +787,22 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_QWEN3: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_QWEN3MOE: + { + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_PHI2: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); @@ -2360,6 +2376,77 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp}, 0); } } break; + case LLM_ARCH_QWEN3: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", 
i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } break; + case LLM_ARCH_QWEN3MOE: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + + if (n_expert == 0) { + throw std::runtime_error("n_expert must be > 0 for QWEN3MOE"); + } + if (n_expert_used == 0) { + throw std::runtime_error("n_expert_used must be > 0 for QWEN3MOE"); + } + + // MoE branch + const int64_t n_ff_exp = hparams.n_ff_exp ? 
hparams.n_ff_exp : n_ff / n_expert_used; + + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + } + } break; case LLM_ARCH_PHI2: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -4168,6 +4255,10 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); } + if (arch == LLM_ARCH_QWEN3MOE) { + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + } + if (arch == LLM_ARCH_MINICPM || arch == LLM_ARCH_GRANITE || arch == LLM_ARCH_GRANITE_MOE) { LLAMA_LOG_INFO("%s: f_embedding_scale = %f\n", __func__, hparams.f_embedding_scale); LLAMA_LOG_INFO("%s: f_residual_scale = %f\n", __func__, hparams.f_residual_scale); @@ -6582,6 +6673,255 @@ struct llm_build_qwen2moe : public llm_graph_context { } }; +struct llm_build_qwen3 : public llm_graph_context { + llm_build_qwen3(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); + cb(Kcur, "Kcur_normed", il); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + cur = build_norm(ffn_inp, + 
model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_qwen3moe : public llm_graph_context { + llm_build_qwen3moe(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self_attention + { + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); + cb(Kcur, "Kcur_normed", il); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // MoE branch + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + ggml_tensor * moe_out = + build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + false, 0.0, + 
LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, + il); + cb(moe_out, "ffn_moe_out", il); + cur = moe_out; + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + struct llm_build_phi2 : public llm_graph_context { llm_build_phi2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; @@ -12282,6 +12622,14 @@ llm_graph_result_ptr llama_model::build_graph( { llm = std::make_unique<llm_build_qwen2vl>(*this, params, gf); } break; + case LLM_ARCH_QWEN3: + { + llm = std::make_unique<llm_build_qwen3>(*this, params, gf); + } break; + case LLM_ARCH_QWEN3MOE: + { + llm = std::make_unique<llm_build_qwen3moe>(*this, params, gf); + } break; case LLM_ARCH_PHI2: { llm = std::make_unique<llm_build_phi2>(*this, params, gf); } break; @@ -12601,6 +12949,8 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_QWEN: case LLM_ARCH_QWEN2: case LLM_ARCH_QWEN2MOE: + case LLM_ARCH_QWEN3: + case LLM_ARCH_QWEN3MOE: case LLM_ARCH_OLMO2: case LLM_ARCH_OLMOE: case LLM_ARCH_PHI2: From 2391506ace6abb56186def40c7107fdfa694ed55 Mon Sep 17 00:00:00 2001 From: Piotr Kubaj Date: Wed, 9 Apr 2025 23:00:25 +0000 Subject: [PATCH 15/20] ggml-impl.h: fix build on POWER9 (#12855) error: ISO C++17 does not allow 'register' storage class specifier --- ggml/src/ggml-impl.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h index 606175fb9..caa6b9dba 100644 --- a/ggml/src/ggml-impl.h +++ b/ggml/src/ggml-impl.h @@ -355,8 +355,8 @@ GGML_API void ggml_aligned_free(void * ptr, size_t size); #define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x) static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) { - register float f; - register double d; + float f; + double d; __asm__( "mtfprd %0,%2\n" "xscvhpdp %0,%0\n" @@ -368,8 +368,8 @@ GGML_API void ggml_aligned_free(void * ptr, size_t size); } static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) { - register double d; - register ggml_fp16_t r; + double d; + ggml_fp16_t r; __asm__( /* xscvdphp can work on double or single precision */ "xscvdphp %0,%2\n" "mffprd %1,%0\n" : From 31f7803bc4e7c0dcc279ee04c2ecfb76b2afdd3e Mon Sep 17 00:00:00 2001 From: Piotr Kubaj Date: Wed, 9 Apr 2025 23:00:34 +0000 Subject: [PATCH 16/20] ggml-cpu-impl.h: do not redefine bool on POWER9 (#12856) error: unknown type name '_Bool' --- ggml/src/ggml-cpu/ggml-cpu-impl.h | 2 -- 1 file changed, 2 deletions(-) diff --git a/ggml/src/ggml-cpu/ggml-cpu-impl.h b/ggml/src/ggml-cpu/ggml-cpu-impl.h index 8eed9bb57..e4af07635 100644 --- a/ggml/src/ggml-cpu/ggml-cpu-impl.h +++ b/ggml/src/ggml-cpu/ggml-cpu-impl.h @@ -323,8 +323,6 @@ inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) #else #ifdef __POWER9_VECTOR__ #include <altivec.h> -#undef bool -#define bool _Bool #else #if defined(_MSC_VER) || defined(__MINGW32__) #include <intrin.h> From b0091ecc1e5c0f689be856fade3803a534f35c9f Mon Sep 17 00:00:00 2001 From: Rudi Servo Date: Wed, 9 Apr 2025 23:17:12 +0000 Subject: [PATCH 17/20] docker : added all CPU to GPU images (#12749) --- .devops/cuda.Dockerfile | 2 +- .devops/intel.Dockerfile | 2 +- .devops/musa.Dockerfile | 2 +- 
.devops/rocm.Dockerfile | 6 +++--- .devops/vulkan.Dockerfile | 2 +- 5 files changed, 7 insertions(+), 7 deletions(-) diff --git a/.devops/cuda.Dockerfile b/.devops/cuda.Dockerfile index a196111e6..8ae57d2e2 100644 --- a/.devops/cuda.Dockerfile +++ b/.devops/cuda.Dockerfile @@ -21,7 +21,7 @@ COPY . . RUN if [ "${CUDA_DOCKER_ARCH}" != "default" ]; then \ export CMAKE_ARGS="-DCMAKE_CUDA_ARCHITECTURES=${CUDA_DOCKER_ARCH}"; \ fi && \ - cmake -B build -DGGML_NATIVE=OFF -DGGML_CUDA=ON ${CMAKE_ARGS} -DCMAKE_EXE_LINKER_FLAGS=-Wl,--allow-shlib-undefined . && \ + cmake -B build -DGGML_NATIVE=OFF -DGGML_CUDA=ON -DLLAMA_CURL=ON -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON ${CMAKE_ARGS} -DCMAKE_EXE_LINKER_FLAGS=-Wl,--allow-shlib-undefined . && \ cmake --build build --config Release -j$(nproc) RUN mkdir -p /app/lib && \ diff --git a/.devops/intel.Dockerfile b/.devops/intel.Dockerfile index e2b381766..091e1dc5d 100644 --- a/.devops/intel.Dockerfile +++ b/.devops/intel.Dockerfile @@ -17,7 +17,7 @@ RUN if [ "${GGML_SYCL_F16}" = "ON" ]; then \ && export OPT_SYCL_F16="-DGGML_SYCL_F16=ON"; \ fi && \ echo "Building with dynamic libs" && \ - cmake -B build -DGGML_NATIVE=OFF -DGGML_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx ${OPT_SYCL_F16} && \ + cmake -B build -DGGML_NATIVE=OFF -DGGML_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DLLAMA_CURL=ON -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON ${OPT_SYCL_F16} && \ cmake --build build --config Release -j$(nproc) RUN mkdir -p /app/lib && \ diff --git a/.devops/musa.Dockerfile b/.devops/musa.Dockerfile index e8297c694..261a2823a 100644 --- a/.devops/musa.Dockerfile +++ b/.devops/musa.Dockerfile @@ -35,7 +35,7 @@ COPY . . RUN if [ "${MUSA_DOCKER_ARCH}" != "default" ]; then \ export CMAKE_ARGS="-DMUSA_ARCHITECTURES=${MUSA_DOCKER_ARCH}"; \ fi && \ - cmake -B build -DGGML_NATIVE=OFF -DGGML_MUSA=ON ${CMAKE_ARGS} -DCMAKE_EXE_LINKER_FLAGS=-Wl,--allow-shlib-undefined . && \ + cmake -B build -DGGML_NATIVE=OFF -DGGML_MUSA=ON -DLLAMA_CURL=ON -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON ${CMAKE_ARGS} -DCMAKE_EXE_LINKER_FLAGS=-Wl,--allow-shlib-undefined . && \ cmake --build build --config Release -j$(nproc) RUN mkdir -p /app/lib && \ diff --git a/.devops/rocm.Dockerfile b/.devops/rocm.Dockerfile index 66687a25b..a1b34723a 100644 --- a/.devops/rocm.Dockerfile +++ b/.devops/rocm.Dockerfile @@ -17,8 +17,8 @@ FROM ${BASE_ROCM_DEV_CONTAINER} AS build # gfx906 is deprecated #check https://rocm.docs.amd.com/projects/install-on-linux/en/docs-6.2.4/reference/system-requirements.html -#ARG ROCM_DOCKER_ARCH='gfx803,gfx900,gfx906,gfx908,gfx90a,gfx942,gfx1010,gfx1030,gfx1032,gfx1100,gfx1101,gfx1102' -ARG ROCM_DOCKER_ARCH=gfx1100 +ARG ROCM_DOCKER_ARCH='gfx803,gfx900,gfx906,gfx908,gfx90a,gfx942,gfx1010,gfx1030,gfx1032,gfx1100,gfx1101,gfx1102' +#ARG ROCM_DOCKER_ARCH=gfx1100 # Set nvcc architectured ENV AMDGPU_TARGETS=${ROCM_DOCKER_ARCH} @@ -40,7 +40,7 @@ WORKDIR /app COPY . . RUN HIPCXX="$(hipconfig -l)/clang" HIP_PATH="$(hipconfig -R)" \ - cmake -S . -B build -DGGML_HIP=ON -DAMDGPU_TARGETS=$ROCM_DOCKER_ARCH -DCMAKE_BUILD_TYPE=Release \ + cmake -S . 
-B build -DGGML_HIP=ON -DAMDGPU_TARGETS=$ROCM_DOCKER_ARCH -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON -DCMAKE_BUILD_TYPE=Release -DLLAMA_CURL=ON \ && cmake --build build --config Release -j$(nproc) RUN mkdir -p /app/lib \ diff --git a/.devops/vulkan.Dockerfile b/.devops/vulkan.Dockerfile index 9064f3838..f8f3072e9 100644 --- a/.devops/vulkan.Dockerfile +++ b/.devops/vulkan.Dockerfile @@ -16,7 +16,7 @@ WORKDIR /app COPY . . -RUN cmake -B build -DGGML_NATIVE=OFF -DGGML_VULKAN=1 -DLLAMA_CURL=1 && \ +RUN cmake -B build -DGGML_NATIVE=OFF -DGGML_VULKAN=1 -DLLAMA_CURL=1 -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON && \ cmake --build build --config Release -j$(nproc) RUN mkdir -p /app/lib && \ From 11d07e1e69138b46375e9267b31acd58e3813577 Mon Sep 17 00:00:00 2001 From: Prajwal B Mehendarkar Date: Thu, 10 Apr 2025 04:48:01 +0530 Subject: [PATCH 18/20] Fixes #12823 (#12830) * Including limits file on AIX * Fixes #12823 --- ggml/src/ggml-cpu/simd-mappings.h | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-cpu/simd-mappings.h b/ggml/src/ggml-cpu/simd-mappings.h index e0b5fc38d..d7db9209f 100644 --- a/ggml/src/ggml-cpu/simd-mappings.h +++ b/ggml/src/ggml-cpu/simd-mappings.h @@ -392,7 +392,11 @@ static inline void __avx_f32cx8_store(ggml_fp16_t *x, __m256 y) { #define GGML_F16_VEC_LOAD(p, i) (i & 0x1) ? \ vec_extract_fp32_from_shorth(vec_xl(0, p - GGML_F16_EPR)) : \ vec_extract_fp32_from_shortl(vec_xl(0, p)) -#define GGML_ENDIAN_BYTE(i) ((unsigned char *)&(uint16_t){1})[i] +static inline unsigned char ggml_endian_byte(int i) { + uint16_t tmp_val = 1; + return ((unsigned char *)&tmp_val)[i]; +} +#define GGML_ENDIAN_BYTE(i) ggml_endian_byte(i) #define GGML_F16_VEC_STORE(p, r, i) \ if (i & 0x1) \ vec_xst(vec_pack_to_short_fp32(r[i - GGML_ENDIAN_BYTE(1)], \ From fe5b78c89670b2f37ecb216306bed3e677b49d9f Mon Sep 17 00:00:00 2001 From: Chenguang Li <757486878@qq.com> Date: Thu, 10 Apr 2025 08:51:52 +0800 Subject: [PATCH 19/20] CANN: Support more ops (#12841) * [CANN]Support Opt LOG && MEAN && PAD_REFLECT_1D * [CANN]Support COUNT_EQUAL && STEP && SGN * [CANN]codestyle adjustment * [CANN]codestyle adjustment --------- Signed-off-by: noemotiovon --- ggml/src/ggml-cann/acl_tensor.cpp | 2 + ggml/src/ggml-cann/aclnn_ops.cpp | 84 +++++++++++++++++++++++++++++++ ggml/src/ggml-cann/aclnn_ops.h | 63 +++++++++++++++++++++++ ggml/src/ggml-cann/ggml-cann.cpp | 24 +++++++++ 4 files changed, 173 insertions(+) diff --git a/ggml/src/ggml-cann/acl_tensor.cpp b/ggml/src/ggml-cann/acl_tensor.cpp index 9b6553c50..f5462c5a1 100644 --- a/ggml/src/ggml-cann/acl_tensor.cpp +++ b/ggml/src/ggml-cann/acl_tensor.cpp @@ -41,6 +41,8 @@ aclDataType ggml_cann_type_mapping(ggml_type type) { return ACL_INT4; case GGML_TYPE_Q8_0: return ACL_INT8; + case GGML_TYPE_I64: + return ACL_INT64; default: return ACL_DT_UNDEFINED; } diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index 25b2599c7..37d411797 100644 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -59,6 +59,11 @@ #include #include #include +#include +#include +#include +#include +#include #include #include @@ -2598,6 +2603,7 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) { aclTensor* acl_dst = ggml_cann_create_tensor(dst, dst->ne, dst->nb, 3); GGML_CANN_CALL_ACLNN_OP(ArgMax, acl_src, 3, false, acl_dst); + ACL_CHECK(aclDestroyTensor(acl_src)); ACL_CHECK(aclDestroyTensor(acl_dst)); } @@ -2629,6 +2635,9 @@ void 
ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* ds ACL_CHECK(aclDestroyTensor(acl_weight)); ACL_CHECK(aclDestroyTensor(acl_dst)); + ACL_CHECK(aclDestroyIntArray(stride)); + ACL_CHECK(aclDestroyIntArray(padding)); + ACL_CHECK(aclDestroyIntArray(dilation)); } void ggml_cann_elu(ggml_backend_cann_context& ctx, ggml_tensor* dst){ @@ -2646,4 +2655,79 @@ void ggml_cann_elu(ggml_backend_cann_context& ctx, ggml_tensor* dst){ ACL_CHECK(aclDestroyTensor(acl_input)); ACL_CHECK(aclDestroyTensor(acl_dst)); + ACL_CHECK(aclDestroyScalar(alpha)); +} + +void ggml_cann_mean(ggml_backend_cann_context& ctx, ggml_tensor* dst){ + ggml_tensor * src0 = dst->src[0]; + + aclTensor* acl_src = ggml_cann_create_tensor(src0); + aclTensor* acl_dst = ggml_cann_create_tensor(dst); + + int64_t reduceDimValue[] = {3}; + aclIntArray* reduceDim = aclCreateIntArray(reduceDimValue, 1); + bool keepDim = true; + + GGML_CANN_CALL_ACLNN_OP(Mean, acl_src, reduceDim, keepDim, ACL_FLOAT, acl_dst); + + ACL_CHECK(aclDestroyTensor(acl_src)); + ACL_CHECK(aclDestroyTensor(acl_dst)); + ACL_CHECK(aclDestroyIntArray(reduceDim)); +} + +void ggml_cann_pad_reflect_1d(ggml_backend_cann_context& ctx, ggml_tensor* dst){ + ggml_tensor * src0 = dst->src[0]; + int32_t *opts = (int32_t *) dst->op_params; + int64_t paddingsArray[2] = {opts[0], opts[1]}; + aclIntArray* paddings = aclCreateIntArray(paddingsArray, 2); + + for (int64_t i = 0; i < src0->ne[3]; i++) { + aclTensor* acl_src = ggml_cann_create_tensor( + (char*)src0->data + i * src0->ne[3], + ggml_cann_type_mapping(src0->type), ggml_element_size(src0), + src0->ne, src0->nb, 3); + + aclTensor* acl_dst = ggml_cann_create_tensor( + (char*)dst->data + i * src0->ne[3], + ggml_cann_type_mapping(dst->type), ggml_element_size(dst), + dst->ne, dst->nb, 3); + + GGML_CANN_CALL_ACLNN_OP(ReflectionPad1d, acl_src, paddings, acl_dst); + + ACL_CHECK(aclDestroyTensor(acl_src)); + ACL_CHECK(aclDestroyTensor(acl_dst)); + } + ACL_CHECK(aclDestroyIntArray(paddings)); +} + +void ggml_cann_count_equal(ggml_backend_cann_context& ctx, ggml_tensor* dst){ + ggml_tensor * src0 = dst->src[0]; + ggml_tensor * src1 = dst->src[1]; + + aclTensor* acl_self = ggml_cann_create_tensor(src0); + aclTensor* acl_other = ggml_cann_create_tensor(src1); + + GGML_CANN_CALL_ACLNN_OP(InplaceEqTensor, acl_self, acl_other); + + ggml_cann_sum(ctx, dst); + + ACL_CHECK(aclDestroyTensor(acl_self)); + ACL_CHECK(aclDestroyTensor(acl_other)); +} + +void ggml_cann_step(ggml_backend_cann_context& ctx, ggml_tensor* dst){ + ggml_tensor * src0 = dst->src[0]; + + aclTensor* acl_src = ggml_cann_create_tensor(src0); + aclTensor* acl_dst = ggml_cann_create_tensor(dst); + + float alphaValue = 0.0f; + aclScalar* alpha = nullptr; + alpha = aclCreateScalar(&alphaValue, aclDataType::ACL_FLOAT); + + GGML_CANN_CALL_ACLNN_OP(GtScalar, acl_src, alpha, acl_dst); + + ACL_CHECK(aclDestroyTensor(acl_src)); + ACL_CHECK(aclDestroyTensor(acl_dst)); + ACL_CHECK(aclDestroyScalar(alpha)); } diff --git a/ggml/src/ggml-cann/aclnn_ops.h b/ggml/src/ggml-cann/aclnn_ops.h index aadf013de..b2d1b3c36 100644 --- a/ggml/src/ggml-cann/aclnn_ops.h +++ b/ggml/src/ggml-cann/aclnn_ops.h @@ -42,6 +42,8 @@ #include #include #include +#include +#include #include "acl_tensor.h" #include "common.h" @@ -650,6 +652,67 @@ void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* ds */ void ggml_cann_elu(ggml_backend_cann_context& ctx, ggml_tensor* dst); +/** + * @brief Computes the mean of a ggml tensor element-wise using the CANN backend. 
+ * + * @details This function calculates the element-wise mean of the input tensor. + * The result is written to the destination tensor `dst`. + * The mean is computed by averaging the values across the entire tensor. + * + * This operation is optimized using the CANN backend for high-performance inference or training. + * + * @param ctx The CANN context used for operations. + * @param dst The destination tensor where the mean result will be stored. + * dst->op is expected to be `GGML_OP_MEAN`. + */ +void ggml_cann_mean(ggml_backend_cann_context& ctx, ggml_tensor* dst); + +/** + * @brief Applies 1D reflect padding to a ggml tensor using the CANN backend. + * + * @details This function performs 1D reflect padding on the input tensor. + * The amount of padding on each side is specified by parameters stored in `dst->op_params`. + * The operation reflects the values at the borders of the tensor to generate the padded output. + * + * This operation is optimized using the CANN backend for high-performance inference or training. + * + * @param ctx The CANN context used for operations. + * @param dst The destination tensor where the padded result will be stored. + * dst->op is expected to be `GGML_OP_PAD_REFLECT_1D`. + */ +void ggml_cann_pad_reflect_1d(ggml_backend_cann_context& ctx, ggml_tensor* dst); + +/** + * @brief Counts the number of equal elements in two ggml tensors using the CANN backend. + * + * @details This function performs an element-wise comparison between two input tensors, + * and counts the number of positions where the elements are equal. The result is + * stored in the destination tensor `dst` as a scalar. + * + * The operation is optimized using the CANN backend, making it suitable for + * high-performance inference or training scenarios. + * + * @param ctx The CANN context used for operations. + * @param dst The destination tensor where the result will be stored. + * dst->op is expected to be `GGML_OP_COUNT_EQUAL`. + */ +void ggml_cann_count_equal(ggml_backend_cann_context& ctx, ggml_tensor* dst); + +/** + * @brief Applies the Step activation function to a ggml tensor using the CANN backend. + * + * @details This function applies a step function element-wise to the input tensor, where + * each element is transformed to 1.0 if it is greater than 0, and 0.0 otherwise. + * The result is stored in the destination tensor `dst`. + * + * This operation is accelerated using the CANN backend to improve runtime performance. + * + * @param ctx The CANN context used for operations. + * @param dst The destination tensor where the result will be stored. + * dst->op is expected to be `GGML_OP_STEP`. + */ +void ggml_cann_step(ggml_backend_cann_context& ctx, ggml_tensor* dst); + /** * @brief Applies a element-wise operation to two input tensors using the CANN * backend. 
diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index f9187ba81..b513270c6 100644 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -1358,6 +1358,12 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx, case GGML_UNARY_OP_ELU: ggml_cann_elu(ctx, dst); break; + case GGML_UNARY_OP_SGN: + GGML_CANN_CALL_UNARY_OP(Sign); + break; + case GGML_UNARY_OP_STEP: + ggml_cann_step(ctx, dst); + break; default: return false; } @@ -1456,6 +1462,18 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx, case GGML_OP_CONV_TRANSPOSE_1D: ggml_cann_conv_transpose_1d(ctx, dst); break; + case GGML_OP_LOG: + GGML_CANN_CALL_UNARY_OP(Log); + break; + case GGML_OP_MEAN: + ggml_cann_mean(ctx, dst); + break; + case GGML_OP_PAD_REFLECT_1D: + ggml_cann_pad_reflect_1d(ctx, dst); + break; + case GGML_OP_COUNT_EQUAL: + ggml_cann_count_equal(ctx, dst); + break; default: return false; } @@ -1718,6 +1736,8 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, case GGML_UNARY_OP_TANH: case GGML_UNARY_OP_EXP: case GGML_UNARY_OP_ELU: + case GGML_UNARY_OP_SGN: + case GGML_UNARY_OP_STEP: return true; default: return false; @@ -1851,6 +1871,10 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, case GGML_OP_COS: case GGML_OP_SIN: case GGML_OP_CONV_TRANSPOSE_1D: + case GGML_OP_LOG: + case GGML_OP_MEAN: + case GGML_OP_PAD_REFLECT_1D: + case GGML_OP_COUNT_EQUAL: return true; default: return false; From 64eda5deb9859e87a020e56bab5d2f9ca956f1de Mon Sep 17 00:00:00 2001 From: Xuan-Son Nguyen Date: Thu, 10 Apr 2025 17:24:44 +0200 Subject: [PATCH 20/20] convert : ability to lazy-load safetensors remotely without downloading to disk (#12820) * gguf util : add SafetensorRemote * fix style * convert: add --remote option * convert : allow using lazy remote tensors It's a bit slow for now since everything is blocking and single-threaded. * correct metadata.name * small style fix * support HF_TOKEN * convert : use writeable buffer for remote lazy tensors * convert : fix flake8 lint regarding lamdba assigment * multithreaded download * multithread: print debug * fix style * Revert "multithreaded download" This reverts commit 42fc895ace385edc972ad819c76c704aeea61791. * bring back _get_request_headers --------- Co-authored-by: Francis Couture-Harpin --- convert_hf_to_gguf.py | 56 ++++++++++-- gguf-py/gguf/utility.py | 195 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 244 insertions(+), 7 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 656dc9877..c9ac2957f 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -65,6 +65,7 @@ class Model: model_name: str | None metadata_override: Path | None dir_model_card: Path + remote_hf_model_id: str | None # subclasses should define this! 
model_arch: gguf.MODEL_ARCH @@ -73,7 +74,7 @@ class Model: use_temp_file: bool = False, eager: bool = False, metadata_override: Path | None = None, model_name: str | None = None, split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, - small_first_shard: bool = False, hparams: dict[str, Any] | None = None): + small_first_shard: bool = False, hparams: dict[str, Any] | None = None, remote_hf_model_id: str | None = None): if type(self) is Model: raise TypeError(f"{type(self).__name__!r} should not be directly instantiated") @@ -83,11 +84,24 @@ class Model: self.is_big_endian = is_big_endian self.endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE self.use_temp_file = use_temp_file - self.lazy = not eager - self.part_names = Model.get_model_part_names(self.dir_model, "model", ".safetensors") - self.is_safetensors = len(self.part_names) > 0 - if not self.is_safetensors: - self.part_names = Model.get_model_part_names(self.dir_model, "pytorch_model", ".bin") + self.lazy = not eager or (remote_hf_model_id is not None) + self.remote_hf_model_id = remote_hf_model_id + if remote_hf_model_id is not None: + self.is_safetensors = True + + def get_remote_tensors() -> Iterator[tuple[str, Tensor]]: + logger.info(f"Using remote model with HuggingFace id: {remote_hf_model_id}") + remote_tensors = gguf.utility.SafetensorRemote.get_list_tensors_hf_model(remote_hf_model_id) + self.tensor_names = set(name for name in remote_tensors.keys()) + for name, remote_tensor in gguf.utility.SafetensorRemote.get_list_tensors_hf_model(remote_hf_model_id).items(): + yield (name, LazyTorchTensor.from_remote_tensor(remote_tensor)) + + self.get_tensors = get_remote_tensors + else: + self.part_names = Model.get_model_part_names(self.dir_model, "model", ".safetensors") + self.is_safetensors = len(self.part_names) > 0 + if not self.is_safetensors: + self.part_names = Model.get_model_part_names(self.dir_model, "pytorch_model", ".bin") self.hparams = Model.load_hparams(self.dir_model) if hparams is None else hparams self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers"]) self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count) @@ -393,6 +407,10 @@ class Model: self.metadata = gguf.Metadata.load(self.metadata_override, self.dir_model_card, self.model_name, total_params) + # If we are using HF model id, set the metadata name to the model id + if self.remote_hf_model_id: + self.metadata.name = self.remote_hf_model_id + # Fallback to model directory name if metadata name is still missing if self.metadata.name is None: self.metadata.name = self.dir_model.name @@ -5403,6 +5421,14 @@ class LazyTorchTensor(gguf.LazyBase): lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(st_slice,), func=lambda s: s[:]) return cast(torch.Tensor, lazy) + @classmethod + def from_remote_tensor(cls, remote_tensor: gguf.utility.RemoteTensor): + dtype = cls._dtype_str_map[remote_tensor.dtype] + shape = remote_tensor.shape + meta = cls.meta_with_dtype_and_shape(dtype, shape) + lazy = cls(meta=meta, args=(remote_tensor,), func=lambda r: torch.frombuffer(r.data(), dtype=dtype).reshape(shape)) + return cast(torch.Tensor, lazy) + @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): del types # unused @@ -5480,6 +5506,10 @@ def parse_args() -> argparse.Namespace: "--print-supported-models", action="store_true", help="Print the supported models" ) + parser.add_argument( + "--remote", action="store_true", + 
help="(Experimental) Read safetensors file remotely without downloading to disk. Config and tokenizer files will still be downloaded. To use this feature, you need to specify Hugging Face model repo name instead of a local directory. For example: 'HuggingFaceTB/SmolLM2-1.7B-Instruct'. Note: To access gated repo, set HF_TOKEN environment variable to your Hugging Face token.", + ) args = parser.parse_args() if not args.print_supported_models and args.model is None: @@ -5520,6 +5550,14 @@ def main() -> None: dir_model = args.model + if args.remote: + from huggingface_hub import snapshot_download + local_dir = snapshot_download( + repo_id=str(dir_model), + allow_patterns=["LICENSE", "*.json", "*.md", "*.txt", "tokenizer.model"]) + dir_model = Path(local_dir) + logger.info(f"Downloaded config and tokenizer to {local_dir}") + if not dir_model.is_dir(): logger.error(f'Error: {args.model} is not a directory') sys.exit(1) @@ -5541,6 +5579,9 @@ def main() -> None: if args.outfile is not None: fname_out = args.outfile + elif args.remote: + # if remote, use the model ID as the output file name + fname_out = Path("./" + str(args.model).replace("/", "-") + "-{ftype}.gguf") else: fname_out = dir_model @@ -5564,7 +5605,8 @@ def main() -> None: metadata_override=args.metadata, model_name=args.model_name, split_max_tensors=args.split_max_tensors, split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run, - small_first_shard=args.no_tensor_first_split) + small_first_shard=args.no_tensor_first_split, + remote_hf_model_id=str(args.model) if args.remote else None) if args.vocab_only: logger.info("Exporting model vocab...") diff --git a/gguf-py/gguf/utility.py b/gguf-py/gguf/utility.py index ae92d786a..e5251aef8 100644 --- a/gguf-py/gguf/utility.py +++ b/gguf-py/gguf/utility.py @@ -1,7 +1,11 @@ from __future__ import annotations +from dataclasses import dataclass from typing import Literal +import os +import json + def fill_templated_filename(filename: str, output_type: str | None) -> str: # Given a file name fill in any type templates e.g. 'some-model-name.{ftype}.gguf' @@ -67,3 +71,194 @@ def naming_convention(model_name: str | None, base_name: str | None, finetune_st kind = f"-{model_type.strip().replace(' ', '-')}" if model_type is not None else "" return f"{name}{parameters}{finetune}{version}{encoding}{kind}" + + +@dataclass +class RemoteTensor: + dtype: str + shape: tuple[int, ...] + offset_start: int + size: int + url: str + + def data(self) -> bytearray: + # TODO: handle request errors (maybe with limited retries?) + # NOTE: using a bytearray, otherwise PyTorch complains the buffer is not writeable + data = bytearray(SafetensorRemote.get_data_by_range(url=self.url, start=self.offset_start, size=self.size)) + return data + + +class SafetensorRemote: + """ + Utility class to handle remote safetensor files. + This class is designed to work with Hugging Face model repositories. 
+ + Example (one model has single safetensor file, the other has multiple): + for model_id in ["ngxson/TEST-Tiny-Llama4", "Qwen/Qwen2.5-7B-Instruct"]: + tensors = SafetensorRemote.get_list_tensors_hf_model(model_id) + print(tensors) + + Example reading tensor data: + tensors = SafetensorRemote.get_list_tensors_hf_model(model_id) + for name, tensor in tensors.items(): + # read the raw tensor bytes via an HTTP range request + data = tensor.data() + print(data) + """ + + BASE_DOMAIN = "https://huggingface.co" + ALIGNMENT = 8 # bytes + + @classmethod + def get_list_tensors_hf_model(cls, model_id: str) -> dict[str, RemoteTensor]: + """ + Get list of tensors from a Hugging Face model repository. + + Returns a dictionary of tensor names and their metadata. + Each tensor is represented as a RemoteTensor carrying (dtype, shape, offset_start, size, url) + """ + # case 1: model has only a single model.safetensors file + is_single_file = cls.check_file_exist(f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/model.safetensors") + if is_single_file: + url = f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/model.safetensors" + return cls.get_list_tensors(url) + + # case 2: model has multiple files + index_url = f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/model.safetensors.index.json" + is_multiple_files = cls.check_file_exist(index_url) + if is_multiple_files: + # read the index file + index_data = cls.get_data_by_range(index_url, 0) + index_str = index_data.decode('utf-8') + index_json = json.loads(index_str) + assert index_json.get("weight_map") is not None, "weight_map not found in index file" + weight_map = index_json["weight_map"] + # get the list of files + all_files = list(set(weight_map.values())) + all_files.sort() # make sure we load shard files in order + # get the list of tensors + tensors: dict[str, RemoteTensor] = {} + for file in all_files: + url = f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/{file}" + for key, val in cls.get_list_tensors(url).items(): + tensors[key] = val + return tensors + + raise ValueError(f"Model {model_id} does not have any safetensor files") + + @classmethod + def get_list_tensors(cls, url: str) -> dict[str, RemoteTensor]: + """ + Get list of tensors from a remote safetensor file. + + Returns a dictionary of tensor names and their metadata. + Each tensor is represented as a RemoteTensor carrying its absolute (dtype, shape, offset_start, size) + """ + metadata, data_start_offset = cls.get_metadata(url) + res: dict[str, RemoteTensor] = {} + + for name, meta in metadata.items(): + if name == "__metadata__": + continue + if not isinstance(meta, dict): + raise ValueError(f"Invalid metadata for tensor '{name}': {meta}") + try: + dtype = meta["dtype"] + shape = meta["shape"] + offset_start_relative, offset_end_relative = meta["data_offsets"] + size = offset_end_relative - offset_start_relative + offset_start = data_start_offset + offset_start_relative + res[name] = RemoteTensor(dtype=dtype, shape=tuple(shape), offset_start=offset_start, size=size, url=url) + except KeyError as e: + raise ValueError(f"Missing key in metadata for tensor '{name}': {e}, meta = {meta}") + + return res + + @classmethod + def get_metadata(cls, url: str) -> tuple[dict, int]: + """ + Get JSON metadata from a remote safetensor file. 
+ + Returns tuple of (metadata, data_start_offset) + """ + # Request first 5MB of the file (hopefully enough for metadata) + read_size = 5 * 1024 * 1024 + raw_data = cls.get_data_by_range(url, 0, read_size) + + # Parse header + # First 8 bytes contain the metadata length as u64 little-endian + if len(raw_data) < 8: + raise ValueError("Not enough data to read metadata size") + metadata_length = int.from_bytes(raw_data[:8], byteorder='little') + + # Calculate the data start offset + data_start_offset = 8 + metadata_length + alignment = SafetensorRemote.ALIGNMENT + if data_start_offset % alignment != 0: + data_start_offset += alignment - (data_start_offset % alignment) + + # Check if we have enough data to read the metadata + if len(raw_data) < 8 + metadata_length: + raise ValueError(f"Could not read complete metadata. Need {8 + metadata_length} bytes, got {len(raw_data)}") + + # Extract metadata bytes and parse as JSON + metadata_bytes = raw_data[8:8 + metadata_length] + metadata_str = metadata_bytes.decode('utf-8') + try: + metadata = json.loads(metadata_str) + return metadata, data_start_offset + except json.JSONDecodeError as e: + raise ValueError(f"Failed to parse safetensor metadata as JSON: {e}") + + @classmethod + def get_data_by_range(cls, url: str, start: int, size: int = -1) -> bytes: + """ + Get raw byte data from a remote file by range. + If size is not specified, it will read the entire file. + """ + import requests + from urllib.parse import urlparse + + parsed_url = urlparse(url) + if not parsed_url.scheme or not parsed_url.netloc: + raise ValueError(f"Invalid URL: {url}") + + headers = cls._get_request_headers() + if size > -1: + headers["Range"] = f"bytes={start}-{start + size}" + response = requests.get(url, allow_redirects=True, headers=headers) + response.raise_for_status() + + # Get raw byte data + return response.content[:size] + + @classmethod + def check_file_exist(cls, url: str) -> bool: + """ + Check if a file exists at the given URL. + Returns True if the file exists, False otherwise. + """ + import requests + from urllib.parse import urlparse + + parsed_url = urlparse(url) + if not parsed_url.scheme or not parsed_url.netloc: + raise ValueError(f"Invalid URL: {url}") + + try: + headers = cls._get_request_headers() + headers["Range"] = "bytes=0-0" + response = requests.head(url, allow_redirects=True, headers=headers) + # Success (2xx) or redirect (3xx) + return 200 <= response.status_code < 400 + except requests.RequestException: + return False + + @classmethod + def _get_request_headers(cls) -> dict[str, str]: + """Prepare common headers for requests.""" + headers = {"User-Agent": "convert_hf_to_gguf"} + if os.environ.get("HF_TOKEN"): + headers["Authorization"] = f"Bearer {os.environ['HF_TOKEN']}" + return headers
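A note on the header layout that `SafetensorRemote.get_metadata` in the patch above relies on: a safetensors file starts with an 8-byte little-endian length prefix, followed by that many bytes of JSON metadata, after which the patch rounds the start of the tensor data region up to `ALIGNMENT` (8 bytes). The following is a minimal sketch of the same parsing applied to a local file, useful for sanity-checking the remote code path; "model.safetensors" is a placeholder path, not a file shipped with the repository:

    import json
    import struct

    # Parse a safetensors header the same way SafetensorRemote.get_metadata does,
    # but from a local file instead of an HTTP range request.
    with open("model.safetensors", "rb") as f:
        (metadata_length,) = struct.unpack("<Q", f.read(8))  # u64 little-endian
        metadata = json.loads(f.read(metadata_length))

    # Tensor data begins after the header, rounded up to 8-byte alignment,
    # mirroring the data_start_offset computation in get_metadata above.
    data_start = 8 + metadata_length
    if data_start % 8 != 0:
        data_start += 8 - (data_start % 8)

    for name, meta in metadata.items():
        if name == "__metadata__":
            continue  # optional free-form metadata entry, not a tensor
        begin, end = meta["data_offsets"]
        print(name, meta["dtype"], meta["shape"], "offset", data_start + begin, "size", end - begin)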
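And a minimal end-to-end use of the new API, assuming huggingface.co is reachable (the repo id below is the example from the class docstring; any accessible safetensors repo works, and gated repos additionally require the HF_TOKEN environment variable):

    import gguf

    # Listing tensors downloads only the safetensors headers (a ranged read of
    # the first few MB of each shard), never the full weight files.
    tensors = gguf.utility.SafetensorRemote.get_list_tensors_hf_model("Qwen/Qwen2.5-7B-Instruct")

    for name, t in list(tensors.items())[:3]:
        print(name, t.dtype, t.shape, t.size)
        blob = t.data()  # one HTTP range request covering just this tensor
        print(f"  fetched {len(blob)} bytes")

The same repo id drives the converter end to end with `python convert_hf_to_gguf.py --remote Qwen/Qwen2.5-7B-Instruct`, which streams each tensor lazily through `LazyTorchTensor.from_remote_tensor` instead of reading local safetensors shards.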