diff --git a/common/arg.cpp b/common/arg.cpp index 44984f687..104a6f8ca 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -2657,6 +2657,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.i_chunk = value; } ).set_examples({LLAMA_EXAMPLE_IMATRIX})); + add_opt(common_arg( + {"--show-statistics"}, + string_format("show imatrix statistics and then exit (default: %s)", params.show_statistics ? "true" : "false"), + [](common_params & params) { + params.show_statistics = true; + } + ).set_examples({LLAMA_EXAMPLE_IMATRIX})); add_opt(common_arg( {"--parse-special"}, string_format("parse special tokens (chat, tool, etc) (default: %s)", params.parse_special ? "true" : "false"), diff --git a/common/common.h b/common/common.h index 0e8fa156f..6ba1df613 100644 --- a/common/common.h +++ b/common/common.h @@ -428,9 +428,10 @@ struct common_params { int32_t n_save_freq = 0; // save the imatrix every n_save_freq iterations int32_t i_chunk = 0; // start processing from this chunk - bool process_output = false; // collect data for the output tensor - bool compute_ppl = true; // whether to compute perplexity - bool parse_special = false; // whether to parse special tokens during imatrix tokenization + bool process_output = false; // collect data for the output tensor + bool compute_ppl = true; // whether to compute perplexity + bool show_statistics = false; // show imatrix statistics per tensor + bool parse_special = false; // whether to parse special tokens during imatrix tokenization // cvector-generator params int n_pca_batch = 100; diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index c8bf3c538..e12c922bd 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -6486,7 +6486,7 @@ class JaisModel(TextModel): self.gguf_writer.add_max_alibi_bias(self.max_alibi_bias) -@ModelBase.register("Glm4ForCausalLM") +@ModelBase.register("Glm4ForCausalLM", "Glm4vForConditionalGeneration") class Glm4Model(TextModel): model_arch = gguf.MODEL_ARCH.GLM4 @@ -6508,7 +6508,8 @@ class Glm4Model(TextModel): def set_gguf_parameters(self): super().set_gguf_parameters() - rope_dim = self.hparams["head_dim"] + if (rope_dim := self.hparams.get("head_dim")) is None: + rope_dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"] self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.hparams.get("partial_rotary_factor", 0.5))) rope_scaling = self.hparams.get("rope_scaling") or {} if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling: @@ -6516,6 +6517,13 @@ class Glm4Model(TextModel): self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"]) self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"]) + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + if name.startswith("model.visual."): # ignore visual part of Glm4v + return [] + elif name.startswith("model.language_model."): + name = name.replace("language_model.", "") # for Glm4v + return super().modify_tensors(data_torch, name, bid) + @ModelBase.register("GlmForCausalLM", "ChatGLMModel", "ChatGLMForConditionalGeneration") class ChatGLMModel(TextModel): diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index 04a35806b..601b2a3bc 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -647,6 +647,7 @@ struct ggml_backend_sched { // pipeline parallelism support int n_copies; int cur_copy; + int next_copy;
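Aside, for illustration only (this block is not part of the patch): the ggml-backend.cpp change splits the pipeline-parallelism copy index in two. ggml_backend_sched_compute_splits() no longer rotates cur_copy itself; ggml_backend_sched_alloc_graph() now promotes next_copy to cur_copy and advances next_copy, and ggml_backend_sched_synchronize() resets next_copy rather than cur_copy. A toy sketch of that rotation, with stand-in names rather than the real scheduler struct:

```cpp
// Toy model of the cur_copy/next_copy split (stand-in struct, not the real ggml_backend_sched).
struct sched_copies {
    int n_copies  = 4;
    int cur_copy  = 0; // copy used by the graph that is currently allocated
    int next_copy = 0; // copy that the next allocation will claim

    // cf. ggml_backend_sched_alloc_graph(): the graph being allocated takes next_copy,
    // and the rotation happens here instead of in compute_splits().
    void alloc_graph() {
        cur_copy  = next_copy;
        next_copy = (next_copy + 1) % n_copies;
    }

    // cf. ggml_backend_sched_synchronize(): only next_copy is reset, so the copy used by an
    // already-allocated graph stays stable (the real code only resets when no graph is allocated).
    void synchronize() {
        next_copy = 0;
    }
};
```

Keeping the allocated graph's copy stable across a synchronize is what the "avoids changes in the graph that could cause CUDA or other graphs to be disabled" comment in the patch relies on.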
ggml_backend_event_t events[GGML_SCHED_MAX_BACKENDS][GGML_SCHED_MAX_COPIES]; struct ggml_tensor * graph_inputs[GGML_SCHED_MAX_SPLIT_INPUTS]; int n_graph_inputs; @@ -1439,8 +1440,6 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s } } - sched->cur_copy = (sched->cur_copy + 1) % sched->n_copies; - return GGML_STATUS_SUCCESS; } @@ -1541,10 +1540,10 @@ void ggml_backend_sched_reset(ggml_backend_sched_t sched) { bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph) { GGML_ASSERT((int)sched->hash_set.size >= measure_graph->n_nodes + measure_graph->n_leafs); - ggml_backend_sched_split_graph(sched, measure_graph); - ggml_backend_sched_synchronize(sched); + ggml_backend_sched_split_graph(sched, measure_graph); + if (!ggml_gallocr_reserve_n(sched->galloc, &sched->graph, sched->node_backend_ids, sched->leaf_backend_ids)) { return false; } @@ -1556,6 +1555,10 @@ bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * bool ggml_backend_sched_alloc_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) { GGML_ASSERT((int)sched->hash_set.size >= graph->n_nodes + graph->n_leafs); + GGML_ASSERT(!sched->is_alloc); + + sched->cur_copy = sched->next_copy; + sched->next_copy = (sched->next_copy + 1) % sched->n_copies; ggml_backend_sched_split_graph(sched, graph); @@ -1596,7 +1599,7 @@ void ggml_backend_sched_synchronize(ggml_backend_sched_t sched) { // if the graph is not already allocated, always use copy 0 after a synchronization // this ensures that during generation the same copy is used every time, // which avoids changes in the graph that could cause CUDA or other graphs to be disabled - sched->cur_copy = 0; + sched->next_copy = 0; } } diff --git a/ggml/src/ggml-cpu/arch/loongarch/quants.c b/ggml/src/ggml-cpu/arch/loongarch/quants.c index 9e33fb322..7908da4d1 100644 --- a/ggml/src/ggml-cpu/arch/loongarch/quants.c +++ b/ggml/src/ggml-cpu/arch/loongarch/quants.c @@ -544,7 +544,7 @@ void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, i __m128 max4 = __lsx_vfmax_s( lasx_extractf128( max_abs, 1 ), lasx_extractf128( max_abs, 0) ); max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vpickod_d((__m128i) max4, (__m128i)max4 ) ); __m128 tmp = max4; - max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vextrins_w((__m128i)tmp, (__m128i)max4, 0x10 )); + max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vextrins_w((__m128i)tmp, (__m128i)max4, 0x1 )); const float max_scalar = ((v4f32)max4)[0]; // Quantize these floats diff --git a/ggml/src/ggml-cpu/repack.cpp b/ggml/src/ggml-cpu/repack.cpp index a7c6e7eee..44952aea9 100644 --- a/ggml/src/ggml-cpu/repack.cpp +++ b/ggml/src/ggml-cpu/repack.cpp @@ -14,7 +14,6 @@ #include #include #include -#include // for qsort #include // for GGML_ASSERT #include "repack.h" diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 0bcbc494f..d416f23ab 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -769,7 +769,7 @@ struct ggml_tensor_extra_gpu { }; -#if (defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS)) +#if (defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS)) || defined(GGML_MUSA_GRAPHS) #define USE_CUDA_GRAPH #endif diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu index eeaa14bf5..15c927861 100644 --- a/ggml/src/ggml-cuda/convert.cu +++ b/ggml/src/ggml-cuda/convert.cu @@ -6,24 +6,33 @@ #define CUDA_Q8_0_NE_ALIGN 2048 template -static __global__ void dequantize_block(const void * 
__restrict__ vx, dst_t * __restrict__ y, const int64_t k) { - const int64_t i = (int64_t)2*(blockDim.x*blockIdx.x + threadIdx.x); +static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, + const int64_t ne00, const int64_t ne01, const int64_t ne02, + const int64_t s01, const int64_t s02, const int64_t s03) { + const int64_t i00 = 2 * (int64_t(blockDim.x)*blockIdx.x + threadIdx.x); - if (i >= k) { + if (i00 >= ne00) { return; } - const int64_t ib = i/qk; // block index - const int64_t iqs = (i%qk)/qr; // quant index - const int64_t iybs = i - i%qk; // y block start index + const int64_t i01 = blockIdx.y; + const int64_t i02 = blockIdx.z % ne02; + const int64_t i03 = blockIdx.z / ne02; + + const int64_t ibx0 = i03*s03 + i02*s02 + i01*s01; + + const int64_t ib = ibx0 + i00/qk; // block index + const int64_t iqs = (i00%qk)/qr; // quant index + const int64_t iybs = i00 - i00%qk; // y block start index const int64_t y_offset = qr == 1 ? 1 : qk/2; // dequantize dfloat2 v; dequantize_kernel(vx, ib, iqs, v); - y[iybs + iqs + 0] = v.x; - y[iybs + iqs + y_offset] = v.y; + const int64_t iy0 = ((i03*ne02 + i02)*ne01 + i01)*ne00 + iybs + iqs; + y[iy0 + 0] = float(v.x); + y[iy0 + y_offset] = float(v.y); } template @@ -457,9 +466,17 @@ static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst } template -static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) { - const int num_blocks = (k + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE); - dequantize_block<<>>(vx, y, k); +static void dequantize_block_cuda(const void * vx, dst_t * y, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, + const int64_t s01, const int64_t s02, const int64_t s03, cudaStream_t stream) { + const dim3 num_blocks((ne00 + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE), ne01, ne02*ne03); + dequantize_block<<>> + (vx, y, ne00, ne01, ne02, s01, s02, s03); +} + +template +static void dequantize_block_cont_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) { + dequantize_block_cuda(vx, y, k, 1, 1, 1, k/qk, k/qk, k/qk, stream); } static void dequantize_block_q8_0_f16_cuda(const void * __restrict__ vx, half * __restrict__ y, const int64_t k, cudaStream_t stream) { @@ -624,14 +641,14 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) { case GGML_TYPE_Q4_1: return dequantize_row_q4_1_cuda; case GGML_TYPE_Q5_0: - return dequantize_block_cuda; + return dequantize_block_cont_cuda; case GGML_TYPE_Q5_1: - return dequantize_block_cuda; + return dequantize_block_cont_cuda; case GGML_TYPE_Q8_0: if (fp16_available(ggml_cuda_info().devices[ggml_cuda_get_device()].cc)) { return dequantize_block_q8_0_f16_cuda; } - return dequantize_block_cuda; + return dequantize_block_cont_cuda; case GGML_TYPE_Q2_K: return dequantize_row_q2_K_cuda; case GGML_TYPE_Q3_K: @@ -676,11 +693,11 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { case GGML_TYPE_Q4_1: return dequantize_row_q4_1_cuda; case GGML_TYPE_Q5_0: - return dequantize_block_cuda; + return dequantize_block_cont_cuda; case GGML_TYPE_Q5_1: - return dequantize_block_cuda; + return dequantize_block_cont_cuda; case GGML_TYPE_Q8_0: - return dequantize_block_cuda; + return dequantize_block_cont_cuda; case GGML_TYPE_Q2_K: return dequantize_row_q2_K_cuda; case GGML_TYPE_Q3_K: @@ -722,6 +739,16 @@ to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type) { 
switch (type) { case GGML_TYPE_F32: return convert_unary_cuda; + case GGML_TYPE_Q4_0: + return dequantize_block_cuda; + case GGML_TYPE_Q4_1: + return dequantize_block_cuda; + case GGML_TYPE_Q5_0: + return dequantize_block_cuda; + case GGML_TYPE_Q5_1: + return dequantize_block_cuda; + case GGML_TYPE_Q8_0: + return dequantize_block_cuda; case GGML_TYPE_BF16: return convert_unary_cuda; default: @@ -733,6 +760,16 @@ to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type) { switch (type) { case GGML_TYPE_F32: return convert_unary_cuda; + case GGML_TYPE_Q4_0: + return dequantize_block_cuda; + case GGML_TYPE_Q4_1: + return dequantize_block_cuda; + case GGML_TYPE_Q5_0: + return dequantize_block_cuda; + case GGML_TYPE_Q5_1: + return dequantize_block_cuda; + case GGML_TYPE_Q8_0: + return dequantize_block_cuda; case GGML_TYPE_F16: return convert_unary_cuda; default: @@ -744,6 +781,16 @@ to_fp32_nc_cuda_t ggml_get_to_fp32_nc_cuda(ggml_type type) { switch (type) { case GGML_TYPE_F16: return convert_unary_cuda; + case GGML_TYPE_Q4_0: + return dequantize_block_cuda; + case GGML_TYPE_Q4_1: + return dequantize_block_cuda; + case GGML_TYPE_Q5_0: + return dequantize_block_cuda; + case GGML_TYPE_Q5_1: + return dequantize_block_cuda; + case GGML_TYPE_Q8_0: + return dequantize_block_cuda; case GGML_TYPE_BF16: return convert_unary_cuda; default: diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu index 0e5964907..f9bb02564 100644 --- a/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu @@ -1,9 +1,9 @@ #include "cpy.cuh" #include "dequantize.cuh" #include "cpy-utils.cuh" -#ifdef GGML_USE_MUSA +#if defined(GGML_USE_MUSA) && defined(GGML_MUSA_MUDNN_COPY) #include "ggml-musa/mudnn.cuh" -#endif // GGML_USE_MUSA +#endif // GGML_USE_MUSA && GGML_MUSA_MUDNN_COPY typedef void (*cpy_kernel_t)(const char * cx, char * cdst); @@ -121,7 +121,7 @@ static __global__ void cpy_q_f32(const char * cx, char * cdst_direct, const int // Copy destination pointers to GPU to be available when pointer indirection is in use void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_dest_ptrs, const int host_dest_ptrs_size, cudaStream_t stream) { -#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) +#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS) if (cuda_graph->dest_ptrs_size < host_dest_ptrs_size) { // (re-)allocate GPU memory for destination pointers CUDA_CHECK(cudaStreamSynchronize(stream)); if (cuda_graph->dest_ptrs_d != nullptr) { @@ -314,7 +314,7 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg char ** dest_ptrs_d = nullptr; int graph_cpynode_index = -1; -#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) +#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS) if(ctx.cuda_graph->use_cpy_indirection && !disable_indirection_for_this_node) { dest_ptrs_d = ctx.cuda_graph->dest_ptrs_d; graph_cpynode_index = ctx.cuda_graph->graph_cpynode_index; @@ -324,11 +324,11 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg #endif if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) { GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1)); -#ifdef GGML_USE_MUSA +#if defined(GGML_USE_MUSA) && defined(GGML_MUSA_MUDNN_COPY) if (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16) { CUDA_CHECK(mudnnMemcpyAsync(ctx, src1, src0)); } else -#endif // GGML_USE_MUSA +#endif // GGML_USE_MUSA && GGML_MUSA_MUDNN_COPY { 
CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream)); } @@ -379,7 +379,7 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__, ggml_type_name(src0->type), ggml_type_name(src1->type)); } -#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) +#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS) if(ctx.cuda_graph->use_cpy_indirection && !disable_indirection_for_this_node) { ctx.cuda_graph->graph_cpynode_index = graph_cpynode_index; } diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index 9122fca6c..95e704e39 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -23,33 +23,13 @@ typedef void (* fattn_kernel_t)( const float m1, const uint32_t n_head_log2, const float logit_softcap, - const int ne00, - const int ne01, - const int ne02, - const int ne03, - const int ne10, - const int ne11, - const int ne12, - const int ne13, - const int ne31, - const int ne32, - const int ne33, - const int nb31, - const int nb32, - const int nb33, - const int nb01, - const int nb02, - const int nb03, - const int nb11, - const int nb12, - const int nb13, - const int nb21, - const int nb22, - const int nb23, - const int ne0, - const int ne1, - const int ne2, - const int ne3); + const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03, + const int32_t nb01, const int32_t nb02, const int32_t nb03, + const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13, + const int32_t nb11, const int32_t nb12, const int64_t nb13, + const int32_t nb21, const int32_t nb22, const int64_t nb23, + const int32_t ne31, const int32_t ne32, const int32_t ne33, + const int32_t nb31, const int32_t nb32, const int64_t nb33); typedef half (*vec_dot_KQ_f16_t)( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds); @@ -745,33 +725,58 @@ void launch_fattn( size_t nb23 = V ? 
V->nb[3] : nb13; if (need_f16_K && K->type != GGML_TYPE_F16) { - GGML_ASSERT(ggml_is_contiguously_allocated(K)); - K_f16.alloc(ggml_nelements(K)); - to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type); - to_fp16(K_data, K_f16.ptr, ggml_nelements(K), main_stream); - K_data = (char *) K_f16.ptr; - const size_t bs = ggml_blck_size(K->type); const size_t ts = ggml_type_size(K->type); - nb11 = nb11*bs*sizeof(half)/ts; - nb12 = nb12*bs*sizeof(half)/ts; - nb13 = nb13*bs*sizeof(half)/ts; + K_f16.alloc(ggml_nelements(K)); + if (ggml_is_contiguously_allocated(K)) { + to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type); + to_fp16(K_data, K_f16.ptr, ggml_nelements(K), main_stream); + + nb11 = nb11*bs*sizeof(half)/ts; + nb12 = nb12*bs*sizeof(half)/ts; + nb13 = nb13*bs*sizeof(half)/ts; + } else { + GGML_ASSERT(K->nb[0] == ts); + to_fp16_nc_cuda_t to_fp16 = ggml_get_to_fp16_nc_cuda(K->type); + const int64_t s01 = nb11 / ts; + const int64_t s02 = nb12 / ts; + const int64_t s03 = nb13 / ts; + to_fp16(K_data, K_f16.ptr, K->ne[0], K->ne[1], K->ne[2], K->ne[3], s01, s02, s03, main_stream); + + nb11 = K->ne[0] * sizeof(half); + nb12 = K->ne[1] * nb11; + nb13 = K->ne[2] * nb12; + } + K_data = (char *) K_f16.ptr; } if (V && need_f16_V && V->type != GGML_TYPE_F16) { - GGML_ASSERT(ggml_is_contiguously_allocated(V)); - V_f16.alloc(ggml_nelements(V)); - to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type); - to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream); - V_data = (char *) V_f16.ptr; - const size_t bs = ggml_blck_size(V->type); const size_t ts = ggml_type_size(V->type); - nb21 = nb21*bs*sizeof(half)/ts; - nb22 = nb22*bs*sizeof(half)/ts; - nb23 = nb23*bs*sizeof(half)/ts; + V_f16.alloc(ggml_nelements(V)); + if (ggml_is_contiguously_allocated(V)) { + to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type); + to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream); + V_data = (char *) V_f16.ptr; + + nb21 = nb21*bs*sizeof(half)/ts; + nb22 = nb22*bs*sizeof(half)/ts; + nb23 = nb23*bs*sizeof(half)/ts; + } else { + GGML_ASSERT(V->nb[0] == ts); + to_fp16_nc_cuda_t to_fp16 = ggml_get_to_fp16_nc_cuda(V->type); + const int64_t s01 = nb21 / ts; + const int64_t s02 = nb22 / ts; + const int64_t s03 = nb23 / ts; + to_fp16(V_data, V_f16.ptr, V->ne[0], V->ne[1], V->ne[2], V->ne[3], s01, s02, s03, main_stream); + + nb21 = V->ne[0] * sizeof(half); + nb22 = V->ne[1] * nb21; + nb23 = V->ne[2] * nb22; + } + V_data = (char *) V_f16.ptr; } int parallel_blocks = 1; @@ -867,14 +872,11 @@ void launch_fattn( mask ? ((const char *) mask->data) : nullptr, !stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr, scale, max_bias, m0, m1, n_head_log2, logit_softcap, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0, - mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0, mask ? mask->nb[3] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - nb11, nb12, nb13, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], Q->nb[1], Q->nb[2], Q->nb[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], nb11, nb12, nb13, nb21, nb22, nb23, - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0, + mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0, mask ? 
mask->nb[3] : 0 ); CUDA_CHECK(cudaGetLastError()); diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index 6fa2e7729..565853bfe 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -408,7 +408,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( const int stride_K, const int stride_V, const int stride_mask, - const int jt, half2 * const __restrict__ tile_Q, half2 * const __restrict__ tile_K, half2 * const __restrict__ tile_V, @@ -455,7 +454,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( cp_async_wait_all(); __syncthreads(); flash_attn_ext_f16_load_tile - (V_h2 + k_VKQ_0*stride_V, tile_V, nbatch_V2, stride_V); + (V_h2 + int64_t(k_VKQ_0)*stride_V, tile_V, nbatch_V2, stride_V); } else { constexpr bool use_cp_async = nstages == 1; if (ncols2 > 1 || mask_h2) { @@ -471,7 +470,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( if (nstages <= 1) { constexpr bool use_cp_async = nstages == 1; flash_attn_ext_f16_load_tile - (K_h2 + k_VKQ_0*stride_K + k0_start, tile_K, k0_diff, stride_K); + (K_h2 + int64_t(k_VKQ_0)*stride_K + k0_start, tile_K, k0_diff, stride_K); if (use_cp_async) { cp_async_wait_all(); } @@ -715,7 +714,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( (mask_h2 + (k_VKQ_0 + c::nbatch_fa)/2, tile_mask, stride_mask); } flash_attn_ext_f16_load_tile - (K_h2 + (k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K); + (K_h2 + int64_t(k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K); } } @@ -732,7 +731,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( if (nstages <= 1 && i0_start < reusable_cutoff) { constexpr bool use_cp_async = nstages == 1; flash_attn_ext_f16_load_tile - (V_h2 + k_VKQ_0*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V); + (V_h2 + int64_t(k_VKQ_0)*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V); if (use_cp_async) { cp_async_wait_all(); } @@ -771,8 +770,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( GGML_UNUSED(mask_h2); GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup); GGML_UNUSED(scale); GGML_UNUSED(slope); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_K); GGML_UNUSED(stride_V); - GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K); - GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K); + GGML_UNUSED(stride_mask); GGML_UNUSED(tile_K); GGML_UNUSED(tile_V); GGML_UNUSED(tile_mask); GGML_UNUSED(Q_B); GGML_UNUSED(VKQ_C); GGML_UNUSED(KQ_max); GGML_UNUSED(KQ_rowsum); GGML_UNUSED(kb0); GGML_UNUSED(tile_Q); @@ -920,7 +918,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( (mask_h2 + kb0_start*c::nbatch_fa/2, tile_mask, stride_mask); } flash_attn_ext_f16_load_tile - (K_h2 + kb0_start*c::nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K); + (K_h2 + int64_t(kb0_start)*c::nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K); } // Iterate over ne11 == previous tokens: @@ -928,13 +926,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( constexpr bool last_iter = false; flash_attn_ext_f16_iter (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap, - ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0); + ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0); } { // kb0_start is always < kb0_stop so the last iter can be 
executed unconditionally. constexpr bool last_iter = true; flash_attn_ext_f16_iter (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap, - ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1); + ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1); } // With multi-stage loading there is no __syncthreads at the end of the iter, @@ -1214,33 +1212,13 @@ static __global__ void flash_attn_ext_f16( const float m1, const uint32_t n_head_log2, const float logit_softcap, - const int ne00, - const int ne01, - const int ne02, - const int ne03, - const int ne10, - const int ne11, - const int ne12, - const int ne13, - const int ne31, - const int ne32, - const int ne33, - const int nb31, - const int nb32, - const int nb33, - const int nb01, - const int nb02, - const int nb03, - const int nb11, - const int nb12, - const int nb13, - const int nb21, - const int nb22, - const int nb23, - const int ne0, - const int ne1, - const int ne2, - const int ne3) { + const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03, + const int32_t nb01, const int32_t nb02, const int32_t nb03, + const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13, + const int32_t nb11, const int32_t nb12, const int64_t nb13, + const int32_t nb21, const int32_t nb22, const int64_t nb23, + const int32_t ne31, const int32_t ne32, const int32_t ne33, + const int32_t nb31, const int32_t nb32, const int64_t nb33) { #if defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE) // Skip unused kernel variants for faster compilation: @@ -1359,8 +1337,7 @@ static __global__ void flash_attn_ext_f16( GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21); - GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); - GGML_UNUSED(ne2); GGML_UNUSED(ne3); + GGML_UNUSED(nb22); GGML_UNUSED(nb23); NO_DEVICE_CODE; #endif // defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE) } diff --git a/ggml/src/ggml-cuda/fattn-tile-f16.cu b/ggml/src/ggml-cuda/fattn-tile-f16.cu index eadb943ca..24473ad51 100644 --- a/ggml/src/ggml-cuda/fattn-tile-f16.cu +++ b/ggml/src/ggml-cuda/fattn-tile-f16.cu @@ -21,33 +21,13 @@ static __global__ void flash_attn_tile_ext_f16( const float m1, const uint32_t n_head_log2, const float logit_softcap, - const int ne00, - const int ne01, - const int ne02, - const int ne03, - const int ne10, - const int ne11, - const int ne12, - const int ne13, - const int ne31, - const int ne32, - const int ne33, - const int nb31, - const int nb32, - const int nb33, - const int nb01, - const int nb02, - const int nb03, - const int nb11, - const int nb12, - const int nb13, - const int nb21, - const int nb22, - const int nb23, - const int ne0, - const int ne1, - const int ne2, - const int ne3) { + const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03, + const int32_t nb01, const int32_t nb02, const int32_t nb03, + const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13, + const int32_t nb11, const int32_t nb12, const int64_t nb13, + const int32_t nb21, const int32_t nb22, const int64_t nb23, + const int32_t ne31, const int32_t ne32, const int32_t ne33, + const int32_t nb31, 
const int32_t nb32, const int64_t nb33) { #if defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE) // Skip unused kernel variants for faster compilation: @@ -127,7 +107,7 @@ static __global__ void flash_attn_tile_ext_f16( for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) { const int k_KQ = k_KQ_0 + threadIdx.x; - KV_tmp[i_KQ][k_KQ] = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ]; + KV_tmp[i_KQ][k_KQ] = K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ]; } } @@ -221,7 +201,7 @@ static __global__ void flash_attn_tile_ext_f16( for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { const int i = i0 + threadIdx.x; - KV_tmp[k][i] = V_h2[(k_VKQ_0 + k)*stride_KV2 + i]; + KV_tmp[k][i] = V_h2[int64_t(k_VKQ_0 + k)*stride_KV2 + i]; } } @@ -300,8 +280,7 @@ static __global__ void flash_attn_tile_ext_f16( GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22); - GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); - GGML_UNUSED(ne2); GGML_UNUSED(ne3); + GGML_UNUSED(nb23); NO_DEVICE_CODE; #endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE) } diff --git a/ggml/src/ggml-cuda/fattn-tile-f32.cu b/ggml/src/ggml-cuda/fattn-tile-f32.cu index a4965583c..2e2ed5cd5 100644 --- a/ggml/src/ggml-cuda/fattn-tile-f32.cu +++ b/ggml/src/ggml-cuda/fattn-tile-f32.cu @@ -21,33 +21,13 @@ static __global__ void flash_attn_tile_ext_f32( const float m1, const uint32_t n_head_log2, const float logit_softcap, - const int ne00, - const int ne01, - const int ne02, - const int ne03, - const int ne10, - const int ne11, - const int ne12, - const int ne13, - const int ne31, - const int ne32, - const int ne33, - const int nb31, - const int nb32, - const int nb33, - const int nb01, - const int nb02, - const int nb03, - const int nb11, - const int nb12, - const int nb13, - const int nb21, - const int nb22, - const int nb23, - const int ne0, - const int ne1, - const int ne2, - const int ne3) { + const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03, + const int32_t nb01, const int32_t nb02, const int32_t nb03, + const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13, + const int32_t nb11, const int32_t nb12, const int64_t nb13, + const int32_t nb21, const int32_t nb22, const int64_t nb23, + const int32_t ne31, const int32_t ne32, const int32_t ne33, + const int32_t nb31, const int32_t nb32, const int64_t nb33) { #ifdef FLASH_ATTN_AVAILABLE // Skip unused kernel variants for faster compilation: @@ -66,8 +46,7 @@ static __global__ void flash_attn_tile_ext_f32( GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22); - GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); - GGML_UNUSED(ne2); GGML_UNUSED(ne3); + GGML_UNUSED(nb23); NO_DEVICE_CODE; return; } @@ -135,7 +114,7 @@ static __global__ void flash_attn_tile_ext_f32( #pragma unroll for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 2*WARP_SIZE) { - const half2 tmp = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + threadIdx.x]; + const half2 tmp = K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + threadIdx.x]; KV_tmp[i_KQ][k_KQ_0 + 0*WARP_SIZE + threadIdx.x] = __low2float(tmp); KV_tmp[i_KQ][k_KQ_0 + 1*WARP_SIZE + threadIdx.x] = __high2float(tmp); } @@ -231,8 +210,9 @@ static __global__ void flash_attn_tile_ext_f32( for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) 
{ const int i = i0 + threadIdx.x; - KV_tmp2[k*(D/2) + i].x = __low2float(V_h2[(k_VKQ_0 + k)*stride_KV2 + i]); - KV_tmp2[k*(D/2) + i].y = __high2float(V_h2[(k_VKQ_0 + k)*stride_KV2 + i]); + const half2 tmp = V_h2[int64_t(k_VKQ_0 + k)*stride_KV2 + i]; + KV_tmp2[k*(D/2) + i].x = __low2float(tmp); + KV_tmp2[k*(D/2) + i].y = __high2float(tmp); } } @@ -312,7 +292,6 @@ static __global__ void flash_attn_tile_ext_f32( GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23); - GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3); NO_DEVICE_CODE; #endif // FLASH_ATTN_AVAILABLE } diff --git a/ggml/src/ggml-cuda/fattn-vec-f16.cuh b/ggml/src/ggml-cuda/fattn-vec-f16.cuh index 018c404c2..8b6ad893e 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f16.cuh @@ -18,33 +18,13 @@ static __global__ void flash_attn_vec_ext_f16( const float m1, const uint32_t n_head_log2, const float logit_softcap, - const int ne00, - const int ne01, - const int ne02, - const int ne03, - const int ne10, - const int ne11, - const int ne12, - const int ne13, - const int ne31, - const int ne32, - const int ne33, - const int nb31, - const int nb32, - const int nb33, - const int nb01, - const int nb02, - const int nb03, - const int nb11, - const int nb12, - const int nb13, - const int nb21, - const int nb22, - const int nb23, - const int ne0, - const int ne1, - const int ne2, - const int ne3) { + const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03, + const int32_t nb01, const int32_t nb02, const int32_t nb03, + const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13, + const int32_t nb11, const int32_t nb12, const int64_t nb13, + const int32_t nb21, const int32_t nb22, const int64_t nb23, + const int32_t ne31, const int32_t ne32, const int32_t ne33, + const int32_t nb31, const int32_t nb32, const int64_t nb33) { #if defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE) // Skip unused kernel variants for faster compilation: @@ -191,13 +171,16 @@ static __global__ void flash_attn_vec_ext_f16( half2 VKQ[ncols] = {{0.0f, 0.0f}}; + K += blockIdx.y*D * nb11; + V += blockIdx.y*D * nb21; + maskh += blockIdx.y*D; for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) { // Calculate KQ tile and keep track of new maximum KQ values: if (mask) { #pragma unroll for (int j = 0; j < ncols; ++j) { - maskh_shared[j*D + tid] = slopeh*maskh[j*ne11 + k_VKQ_0 + tid]; + maskh_shared[j*D + tid] = slopeh*maskh[j*ne11 + tid]; } __syncthreads(); @@ -244,7 +227,7 @@ static __global__ void flash_attn_vec_ext_f16( #pragma unroll for (int j = 0; j < ncols; ++j) { - half sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_h2[j], Q_i32[j], Q_ds[j]); + half sum = vec_dot_KQ(K + i_KQ*nb11, Q_h2[j], Q_i32[j], Q_ds[j]); sum = warp_reduce_sum((float)sum); if (use_logit_softcap) { @@ -300,14 +283,18 @@ static __global__ void flash_attn_vec_ext_f16( } half2 V_k; - reinterpret_cast(V_k.x) = dequantize_1_v(V + (k_VKQ_0 + k0 + 0)*nb21, tid); - reinterpret_cast(V_k.y) = dequantize_1_v(V + (k_VKQ_0 + k0 + 1)*nb21, tid); + reinterpret_cast(V_k.x) = dequantize_1_v(V + (k0 + 0)*nb21, tid); + reinterpret_cast(V_k.y) = dequantize_1_v(V + (k0 + 1)*nb21, tid); #pragma unroll for (int j = 0; j < ncols; ++j) { VKQ[j] += V_k*KQ2[j*(D/2) + k0/2]; } } + K += gridDim.y*D * nb11; + V += gridDim.y*D * nb21; + maskh += gridDim.y*D; + __syncthreads(); } @@ -351,8 
+338,7 @@ static __global__ void flash_attn_vec_ext_f16( GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22); - GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); - GGML_UNUSED(ne2); GGML_UNUSED(ne3); + GGML_UNUSED(nb23); NO_DEVICE_CODE; #endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE) } diff --git a/ggml/src/ggml-cuda/fattn-vec-f32.cuh b/ggml/src/ggml-cuda/fattn-vec-f32.cuh index 405b6f510..6a4bdc0ff 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f32.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f32.cuh @@ -18,33 +18,13 @@ static __global__ void flash_attn_vec_ext_f32( const float m1, const uint32_t n_head_log2, const float logit_softcap, - const int ne00, - const int ne01, - const int ne02, - const int ne03, - const int ne10, - const int ne11, - const int ne12, - const int ne13, - const int ne31, - const int ne32, - const int ne33, - const int nb31, - const int nb32, - const int nb33, - const int nb01, - const int nb02, - const int nb03, - const int nb11, - const int nb12, - const int nb13, - const int nb21, - const int nb22, - const int nb23, - const int ne0, - const int ne1, - const int ne2, - const int ne3) { + const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03, + const int32_t nb01, const int32_t nb02, const int32_t nb03, + const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13, + const int32_t nb11, const int32_t nb12, const int64_t nb13, + const int32_t nb21, const int32_t nb22, const int64_t nb23, + const int32_t ne31, const int32_t ne32, const int32_t ne33, + const int32_t nb31, const int32_t nb32, const int64_t nb33) { #ifdef FLASH_ATTN_AVAILABLE // Skip unused kernel variants for faster compilation: @@ -59,8 +39,7 @@ static __global__ void flash_attn_vec_ext_f32( GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22); - GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); - GGML_UNUSED(ne2); GGML_UNUSED(ne3); + GGML_UNUSED(nb23); NO_DEVICE_CODE; return; } @@ -198,13 +177,16 @@ static __global__ void flash_attn_vec_ext_f32( float VKQ[ncols] = {0.0f}; + K += blockIdx.y*D * nb11; + V += blockIdx.y*D * nb21; + maskh += blockIdx.y*D; for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) { // Calculate KQ tile and keep track of new maximum KQ values: if (mask) { #pragma unroll for (int j = 0; j < ncols; ++j) { - maskf_shared[j*D + tid] = slope*__half2float(maskh[j*ne11 + k_VKQ_0 + tid]); + maskf_shared[j*D + tid] = slope*__half2float(maskh[j*ne11 + tid]); } __syncthreads(); @@ -246,7 +228,7 @@ static __global__ void flash_attn_vec_ext_f32( #pragma unroll for (int j = 0; j < ncols; ++j) { - float sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_f2[j], Q_i32[j], Q_ds[j]); + float sum = vec_dot_KQ(K + i_KQ*nb11, Q_f2[j], Q_i32[j], Q_ds[j]); sum = warp_reduce_sum(sum); if (use_logit_softcap) { @@ -297,13 +279,17 @@ static __global__ void flash_attn_vec_ext_f32( break; } - const float V_ki = dequantize_1_v(V + (k_VKQ_0 + k)*nb21, tid); + const float V_ki = dequantize_1_v(V + k*nb21, tid); #pragma unroll for (int j = 0; j < ncols; ++j) { VKQ[j] += V_ki*KQ[j*D + k]; } } + K += gridDim.y*D * nb11; + V += gridDim.y*D * nb21; + maskh += gridDim.y*D; + __syncthreads(); } @@ -348,7 +334,6 @@ static __global__ void 
flash_attn_vec_ext_f32( GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23); - GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3); NO_DEVICE_CODE; #endif // FLASH_ATTN_AVAILABLE } diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ggml/src/ggml-cuda/fattn-wmma-f16.cu index 5353758f6..2d1af5efc 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16.cu +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cu @@ -37,33 +37,13 @@ static __global__ void flash_attn_ext_f16( const float m1, const uint32_t n_head_log2, const float logit_softcap, - const int ne00, - const int ne01, - const int ne02, - const int ne03, - const int ne10, - const int ne11, - const int ne12, - const int ne13, - const int ne31, - const int ne32, - const int ne33, - const int nb31, - const int nb32, - const int nb33, - const int nb01, - const int nb02, - const int nb03, - const int nb11, - const int nb12, - const int nb13, - const int nb21, - const int nb22, - const int nb23, - const int ne0, - const int ne1, - const int ne2, - const int ne3) { + const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03, + const int32_t nb01, const int32_t nb02, const int32_t nb03, + const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13, + const int32_t nb11, const int32_t nb12, const int64_t nb13, + const int32_t nb21, const int32_t nb22, const int64_t nb23, + const int32_t ne31, const int32_t ne32, const int32_t ne33, + const int32_t nb31, const int32_t nb32, const int64_t nb33) { #if !defined(GGML_HIP_NO_ROCWMMA_FATTN) && defined(FLASH_ATTN_AVAILABLE) && ((__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || __CUDA_ARCH__ == GGML_CUDA_CC_TURING) || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE))) // Skip unused kernel variants for faster compilation: if (use_logit_softcap && !(D == 128 || D == 256)) { @@ -197,7 +177,7 @@ static __global__ void flash_attn_ext_f16( #pragma unroll for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) { frag_a_K K_a; - wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV); + wmma::load_matrix_sync(K_a, K_h + int64_t(k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV); #pragma unroll for (int j = 0; j < ncols/frag_n; ++j) { wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]); @@ -344,7 +324,7 @@ static __global__ void flash_attn_ext_f16( const int k = k0 + (threadIdx.y % VKQ_ratio)*16; frag_a_V v_a; - wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV); + wmma::load_matrix_sync(v_a, V_h + int64_t(k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV); #pragma unroll for (int j = 0; j < ncols/frag_n; ++j) { wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]); @@ -451,7 +431,6 @@ static __global__ void flash_attn_ext_f16( GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23); - GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3); NO_DEVICE_CODE; #endif // defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE))) } diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 
862eaabd7..07a7a6dfb 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -280,22 +280,12 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size; const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV); - if (GGML_CUDA_CC_IS_AMD(cc)) { #if defined(GGML_HIP_ROCWMMA_FATTN) - if (fp16_mma_available(cc)) { - ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst); - return; - } -#endif // defined(GGML_HIP_ROCWMMA_FATTN) - - // On AMD the tile kernels perform poorly, use the vec kernel instead: - if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) { - ggml_cuda_flash_attn_ext_vec_f16(ctx, dst); - } else { - ggml_cuda_flash_attn_ext_vec_f32(ctx, dst); - } + if (GGML_CUDA_CC_IS_AMD(cc) && fp16_mma_available(cc)) { + ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst); return; } +#endif // defined(GGML_HIP_ROCWMMA_FATTN) if (!fast_fp16_available(cc)) { if (Q->ne[1] <= 8 || Q->ne[0] == 256) { diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 8d87f04db..233b0aa5c 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -57,6 +57,7 @@ bool g_mul_mat_q = true; #include #include #include +#include #include #include #include @@ -2770,6 +2771,39 @@ static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) { } #endif +static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list ops) { + if (!ggml_can_fuse(cgraph, node_idx, ops)) { + return false; + } + + if (ops.size() == 2 && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) { + const ggml_tensor *rms_norm = cgraph->nodes[node_idx]; + const ggml_tensor *mul = cgraph->nodes[node_idx+1]; + + GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(rms_norm->type == GGML_TYPE_F32); + + //rms norm only supports F32 + if (mul->src[0]->type != GGML_TYPE_F32 || + mul->src[1]->type != GGML_TYPE_F32 || + mul->type != GGML_TYPE_F32) { + return false; + } + + //if rms norm is the B operand, then we don't handle broadcast + if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm->src[1])) { + return false; + } + + //rms_norm kernel assumes contiguous rows + if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) { + return false; + } + } + + return true; +} + static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, bool & graph_evaluated_or_captured, bool & use_cuda_graph, bool & cuda_graph_update_required) { // flag used to determine whether it is an integrated_gpu @@ -2779,6 +2813,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx // Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph. // With the use of CUDA graphs, the execution will be performed by the graph launch.
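For context on the fusion gated by ggml_cuda_can_fuse() above (an illustrative sketch, not code from this diff): the fused path replaces a GGML_OP_RMS_NORM node followed by a GGML_OP_MUL node with a single kernel launch, so the math it must reproduce is RMS normalization of each row followed by an element-wise scale, with the multiplier broadcast by a modulo when it has fewer columns:

```cpp
#include <cmath>
#include <cstdint>

// Reference semantics of RMS_NORM followed by MUL on a single row of ncols floats.
// mul may be shorter than the row (mul_ncols <= ncols); it is broadcast with a modulo,
// mirroring the mul_col = col % mul_ncols indexing in the fused kernel.
static void rms_norm_mul_row_ref(const float * x, const float * mul, float * dst,
                                 int64_t ncols, int64_t mul_ncols, float eps) {
    float sumsq = 0.0f;
    for (int64_t i = 0; i < ncols; ++i) {
        sumsq += x[i]*x[i];
    }
    const float scale = 1.0f/std::sqrt(sumsq/ncols + eps);
    for (int64_t i = 0; i < ncols; ++i) {
        dst[i] = scale*x[i]*mul[i % mul_ncols];
    }
}
```

The checks in ggml_cuda_can_fuse() ensure this simple rule is enough: both operands must be F32 with contiguous rows, and when the RMS_NORM output is the second (B) operand of the MUL, broadcasting is not handled, so the shapes must already match.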
if (!use_cuda_graph || cuda_graph_update_required) { + for (int i = 0; i < cgraph->n_nodes; i++) { ggml_tensor * node = cgraph->nodes[i]; @@ -2786,6 +2821,12 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx continue; } + static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr); + if (!disable_fusion && ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) { + ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i+1]); + i++; + continue; + } #ifndef NDEBUG assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device)); for (int j = 0; j < GGML_MAX_SRC; j++) { diff --git a/ggml/src/ggml-cuda/norm.cu b/ggml/src/ggml-cuda/norm.cu index 0020dbcec..bddcca51b 100644 --- a/ggml/src/ggml-cuda/norm.cu +++ b/ggml/src/ggml-cuda/norm.cu @@ -104,10 +104,12 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr } } -template +template static __global__ void rms_norm_f32( const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel, - const int64_t stride_sample, const float eps) { + const int64_t stride_sample, const float eps, const float * mul = nullptr, const int64_t mul_stride_row = 0, + const int64_t mul_stride_channel = 0, const int64_t mul_stride_sample = 0, const int mul_ncols = 0, + const int mul_nrows = 0, const int mul_nchannels = 0, const int mul_nsamples = 0) { const int nrows = gridDim.x; const int nchannels = gridDim.y; @@ -119,6 +121,13 @@ static __global__ void rms_norm_f32( x += sample*stride_sample + channel*stride_channel + row*stride_row; dst += ((sample*nchannels + channel)*nrows + row)*ncols; + if constexpr (do_multiply) { + const int mul_row = row % mul_nrows; + const int mul_channel = channel % mul_nchannels; + const int mul_sample = sample % mul_nsamples; + mul += mul_sample*mul_stride_sample + mul_channel*mul_stride_channel + mul_row*mul_stride_row; + } + float tmp = 0.0f; // partial sum for thread in warp for (int col = tid; col < ncols; col += block_size) { @@ -145,7 +154,12 @@ static __global__ void rms_norm_f32( const float scale = rsqrtf(mean + eps); for (int col = tid; col < ncols; col += block_size) { - dst[col] = scale * x[col]; + if constexpr (do_multiply) { + const int mul_col = col % mul_ncols; + dst[col] = scale * x[col] * mul[mul_col]; + } else { + dst[col] = scale * x[col]; + } } } @@ -310,10 +324,30 @@ static void rms_norm_f32_cuda( const dim3 blocks_num(nrows, nchannels, nsamples); if (ncols < 1024) { const dim3 block_dims(WARP_SIZE, 1, 1); - rms_norm_f32<<>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); + rms_norm_f32<<>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); } else { const dim3 block_dims(1024, 1, 1); - rms_norm_f32<1024><<>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); + rms_norm_f32<1024, false><<>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); + } +} + +static void rms_norm_mul_f32_cuda( + const float * x, const float * mul, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples, + const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, + const int64_t mul_stride_row, const int64_t mul_stride_channel, const int64_t mul_stride_sample, + const int mul_ncols, const int mul_nrows, const int mul_nchannels, const int mul_nsamples, + const float eps, cudaStream_t stream) { + const dim3 blocks_num(nrows, nchannels, nsamples); + if (mul == nullptr) { + 
rms_norm_f32_cuda(x, dst, ncols, nrows, nchannels, nsamples, stride_row, stride_channel, stride_sample, eps, stream); + return; + } + if (ncols < 1024) { + const dim3 block_dims(WARP_SIZE, 1, 1); + rms_norm_f32<<>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, mul_stride_sample, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples); + } else { + const dim3 block_dims(1024, 1, 1); + rms_norm_f32<1024, true><<>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, mul_stride_sample, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples); } } @@ -407,6 +441,59 @@ void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { rms_norm_f32_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream); } +void ggml_cuda_op_rms_norm_fused(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * mul_tensor) { + const ggml_tensor * rms_norm_src = (ggml_tensor *) dst->src[0]; + float eps = 0.0f; + + memcpy(&eps, dst->op_params, sizeof(float)); + + const float * src0_d = (const float *) rms_norm_src->data; + const float * mul_d = nullptr; + const ggml_tensor * mul_src = nullptr; + + if (mul_tensor->src[0] == dst) { + mul_d = (float *) mul_tensor->src[1]->data; + mul_src = mul_tensor->src[1]; + } else if(mul_tensor->src[1] == dst) { + mul_d = (float *) mul_tensor->src[0]->data; + mul_src = mul_tensor->src[0]; + } else { + GGML_ASSERT(false); + } + + float * dst_d = (float *) mul_tensor->data; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(rms_norm_src->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + GGML_ASSERT(mul_tensor->type == GGML_TYPE_F32); + GGML_ASSERT(eps >= 0.0f); + + const int64_t ne00 = rms_norm_src->ne[0]; + const int64_t ne01 = rms_norm_src->ne[1]; + const int64_t ne02 = rms_norm_src->ne[2]; + const int64_t ne03 = rms_norm_src->ne[3]; + + const size_t ts0 = ggml_type_size(rms_norm_src->type); + GGML_ASSERT(rms_norm_src->nb[0] == ts0); + const int64_t s01 = rms_norm_src->nb[1] / ts0; + const int64_t s02 = rms_norm_src->nb[2] / ts0; + const int64_t s03 = rms_norm_src->nb[3] / ts0; + + const size_t ts_mul = ggml_type_size(mul_src->type); + GGML_ASSERT(mul_src->nb[0] == ts_mul); + const int64_t mul_s01 = mul_src->nb[1] / ts_mul; + const int64_t mul_s02 = mul_src->nb[2] / ts_mul; + const int64_t mul_s03 = mul_src->nb[3] / ts_mul; + + const int mul_ncols = mul_src->ne[0]; + const int mul_nrows = mul_src->ne[1]; + const int mul_nchannels = mul_src->ne[2]; + const int mul_nsamples = mul_src->ne[3]; + + rms_norm_mul_f32_cuda(src0_d, mul_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, mul_s01, mul_s02, mul_s03, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples, eps, stream); +} + void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * grad = dst->src[0]; // gradients const ggml_tensor * src0f = dst->src[1]; // src0 from forward pass diff --git a/ggml/src/ggml-cuda/norm.cuh b/ggml/src/ggml-cuda/norm.cuh index 706a5660a..7ea7bd4df 100644 --- a/ggml/src/ggml-cuda/norm.cuh +++ b/ggml/src/ggml-cuda/norm.cuh @@ -6,6 +6,8 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst); +void ggml_cuda_op_rms_norm_fused(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * mul_tensor); + void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void 
ggml_cuda_op_l2_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/vendors/musa.h b/ggml/src/ggml-cuda/vendors/musa.h index 937779a90..198963202 100644 --- a/ggml/src/ggml-cuda/vendors/musa.h +++ b/ggml/src/ggml-cuda/vendors/musa.h @@ -13,7 +13,7 @@ #define CUBLAS_OP_N MUBLAS_OP_N #define CUBLAS_OP_T MUBLAS_OP_T #define CUBLAS_STATUS_SUCCESS MUBLAS_STATUS_SUCCESS -#define CUBLAS_TF32_TENSOR_OP_MATH MUBLAS_MATH_MODE_DEFAULT +#define CUBLAS_TF32_TENSOR_OP_MATH MUBLAS_TENSOR_OP_MATH #define CUDA_R_16F MUSA_R_16F #define CUDA_R_16BF MUSA_R_16BF #define CUDA_R_32F MUSA_R_32F @@ -29,7 +29,7 @@ #define cublasSgemm mublasSgemm #define cublasStatus_t mublasStatus_t #define cublasOperation_t mublasOperation_t -#define cublasGetStatusString mublasStatus_to_string +#define cublasGetStatusString mublasGetStatusString #define cudaDataType_t musaDataType_t #define cudaDeviceCanAccessPeer musaDeviceCanAccessPeer #define cudaDeviceDisablePeerAccess musaDeviceDisablePeerAccess diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index b267f394c..a91b5bbbf 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -1955,6 +1955,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex static int ggml_metal_encode_node( ggml_backend_t backend, int idx, + int idx_end, id encoder, struct ggml_metal_mem_pool * mem_pool) { struct ggml_backend_metal_context * ctx = backend->context; @@ -2181,7 +2182,9 @@ static int ggml_metal_encode_node( size_t offs_fuse; id id_fuse; - for (n_fuse = 0; n_fuse <= 6; ++n_fuse) { + // note: in metal, we sometimes encode the graph in parallel so we have to avoid fusing nodes + // across splits. idx_end indicates the last node in the current split + for (n_fuse = 0; n_fuse <= 6 && idx + n_fuse + 1 < idx_end; ++n_fuse) { if (!ggml_can_fuse(gf, idx + n_fuse, ops + n_fuse, 2)) { break; } @@ -4288,7 +4291,7 @@ static int ggml_metal_encode_node( ops[1] = GGML_OP_MUL; ops[2] = GGML_OP_ADD; - for (n_fuse = 0; n_fuse <= 1; ++n_fuse) { + for (n_fuse = 0; n_fuse <= 1 && idx + n_fuse + 1 < idx_end; ++n_fuse) { if (!ggml_can_fuse(gf, idx + n_fuse, ops + n_fuse, 2)) { break; } @@ -6271,7 +6274,11 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) { [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]]; } - const int res = ggml_metal_encode_node(backend, idx, encoder, mem_pool); + const int res = ggml_metal_encode_node(backend, idx, node_end, encoder, mem_pool); + if (idx + res > node_end) { + GGML_ABORT("fusion error: nodes spanning multiple encoders have been fused. 
this indicates a bug in the fusion logic %s", + "https://github.com/ggml-org/llama.cpp/pull/14849"); + } if (should_capture) { [encoder popDebugGroup]; diff --git a/ggml/src/ggml-sycl/quants.hpp b/ggml/src/ggml-sycl/quants.hpp index 8b952db43..d0d5ac9a4 100644 --- a/ggml/src/ggml-sycl/quants.hpp +++ b/ggml/src/ggml-sycl/quants.hpp @@ -48,11 +48,11 @@ template <> struct block_q_t { }; static constexpr std::pair get_block_offset(const int block_index, const int /* nblocks */) { - return { block_index * (traits::qk / traits::qr), 0 }; + return { block_index * (QK4_0 / QR4_0), 0 }; } static constexpr std::pair get_d_offset(int nrows, int ncols, const int block_index) { - return { (ncols / traits::qr * nrows) + block_index * sizeof(ggml_half), 0 }; + return { (ncols / QR4_0 * nrows) + block_index * sizeof(ggml_half), 0 }; } static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; } @@ -71,14 +71,12 @@ template <> struct block_q_t { } static constexpr std::pair get_d_offset(int nrows, int ncols, const int block_index) { - auto nblocks = (nrows * (ncols / traits::qk)); - return { nblocks * (QK_K / 2), + auto nblocks = (nrows * (ncols / QK_K)); + return { nblocks * (QK_K / 2) + (block_index * K_SCALE_SIZE), (nblocks * QK_K / 2) + (nblocks * K_SCALE_SIZE) + (block_index * sizeof(ggml_half2)) }; } static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; } - - constexpr size_t get_total_qs_bytes(int nblocks) { return nblocks * QK_K / 2; } }; template <> struct block_q_t { @@ -90,22 +88,23 @@ template <> struct block_q_t { }; static constexpr std::pair get_block_offset(const int block_index, const int n_blocks) { - auto low_bits_index = block_index * (traits::qk / traits::qr); + auto low_bits_index = block_index * (QK_K / QR6_K); // the index of high bits it's after all low bits auto high_bits_index = n_blocks * (QK_K / 2) + (block_index * (QK_K / 4)); return { low_bits_index, high_bits_index }; } static constexpr std::pair get_d_offset(int nrows, int ncols, const int block_index) { - auto nblocks = (nrows * (ncols / traits::qk)); + auto nblocks = (nrows * (ncols / QK_K)); auto total_qs_bytes = nblocks * (QK_K / 2) + nblocks * (QK_K / 4); auto block_scales = total_qs_bytes + block_index * (QK_K / 16); - auto sb_scale = total_qs_bytes + nblocks * (QK_K / 16); + auto sb_scale = total_qs_bytes + nblocks * (QK_K / 16) + block_index * sizeof(ggml_half); return { block_scales, sb_scale }; } static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; } }; + } // namespace ggml_sycl_reordered #endif // GGML_SYCL_QUANTS_HPP diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index b14a04760..17958aa72 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -10272,7 +10272,7 @@ static bool ggml_vk_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, st } // if rms_norm is the B operand, then we don't handle broadcast if (rms_norm == mul->src[1] && - mul->src[0]->ne[1] != rms_norm->ne[1]) { + !ggml_are_same_shape(mul->src[0], rms_norm)) { return false; } // rms_norm shader assumes contiguous rows diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp index 6428ca7ba..bdd7db2d6 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp @@ -50,8 +50,14 @@ void main() { const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1)); if (do_multiply) { - [[unroll]] for 
(uint col = tid; col < ncols; col += BLOCK_SIZE) { - data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + col])); + if (ncols > p.ne10) { + [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) { + data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + fastmod(col, p.ne10)])); + } + } else { + [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) { + data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + col])); + } } } else { [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) { diff --git a/gguf-py/gguf/metadata.py b/gguf-py/gguf/metadata.py index e807f4346..67efedbdb 100644 --- a/gguf-py/gguf/metadata.py +++ b/gguf-py/gguf/metadata.py @@ -144,6 +144,10 @@ class Metadata: # Quick hack to fix the Norway problem # https://hitchdev.com/strictyaml/why/implicit-typing-removed/ yaml_content = yaml_content.replace("- no\n", "- \"no\"\n") + # yaml should use 2 spaces instead of tabs + # this issue came up with the Qwen/Qwen3-235B-A22B-Instruct-2507 model card + # (I've also sent a PR to fix the model card too) + yaml_content = yaml_content.replace("\t", " ") if yaml_content: data = yaml.safe_load(yaml_content) diff --git a/include/llama.h b/include/llama.h index 1aee4f2d9..b46416b13 100644 --- a/include/llama.h +++ b/include/llama.h @@ -959,6 +959,7 @@ extern "C" { // in the order they have appeared in the batch. // Rows: number of tokens for which llama_batch.logits[i] != 0 // Cols: n_vocab + // TODO: deprecate in favor of llama_get_logits_ith() (ref: https://github.com/ggml-org/llama.cpp/pull/14853#issuecomment-3113143522) LLAMA_API float * llama_get_logits(struct llama_context * ctx); // Logits for the ith token. For positive indices, Equivalent to: @@ -973,6 +974,7 @@ extern "C" { // in the order they have appeared in the batch. // shape: [n_outputs*n_embd] // Otherwise, returns NULL. + // TODO: deprecate in favor of llama_get_embeddings_ith() (ref: https://github.com/ggml-org/llama.cpp/pull/14853#issuecomment-3113143522) LLAMA_API float * llama_get_embeddings(struct llama_context * ctx); // Get the embeddings for the ith token. 
For positive indices, Equivalent to: diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 814ac93a6..062a99776 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -1933,12 +1933,6 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, } }, - { - LLM_ARCH_UNKNOWN, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - }, - }, { LLM_ARCH_DREAM, { @@ -1956,6 +1950,12 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, + { + LLM_ARCH_UNKNOWN, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + }, + }, }; static const std::map LLM_TENSOR_INFOS = { diff --git a/src/llama-chat.cpp b/src/llama-chat.cpp index 80072ad27..d34bb2687 100644 --- a/src/llama-chat.cpp +++ b/src/llama-chat.cpp @@ -718,10 +718,9 @@ int32_t llm_chat_apply_template( } ss << message->content << "<|im_end|>"; - - if (add_ass) { - ss << "<|im_assistant|>assistant<|im_middle|>"; - } + } + if (add_ass) { + ss << "<|im_assistant|>assistant<|im_middle|>"; } } else { // template not supported diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 745e650a6..99adb3675 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -508,12 +508,16 @@ enum llama_pooling_type llama_context::pooling_type() const { } float * llama_context::get_logits() { + output_reorder(); + return logits; } float * llama_context::get_logits_ith(int32_t i) { int64_t j = -1; + output_reorder(); + try { if (logits == nullptr) { throw std::runtime_error("no logits"); @@ -550,12 +554,16 @@ float * llama_context::get_logits_ith(int32_t i) { } float * llama_context::get_embeddings() { + output_reorder(); + return embd; } float * llama_context::get_embeddings_ith(int32_t i) { int64_t j = -1; + output_reorder(); + try { if (embd == nullptr) { throw std::runtime_error("no embeddings"); @@ -970,6 +978,7 @@ int llama_context::decode(const llama_batch & batch_inp) { // TODO: this clear of the buffer can easily be forgotten - need something better embd_seq.clear(); + output_swaps.clear(); bool did_optimize = false; @@ -1189,9 +1198,6 @@ int llama_context::decode(const llama_batch & batch_inp) { // make the outputs have the same order they had in the user-provided batch // note: this is mostly relevant for recurrent models atm if (!sorted_output) { - const uint32_t n_vocab = model.vocab.n_tokens(); - const uint64_t n_embd = model.hparams.n_embd; - GGML_ASSERT((size_t) n_outputs == out_ids.size()); // TODO: is there something more efficient which also minimizes swaps? 
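The llama-context.cpp hunks above make every logits/embeddings accessor call the new output_reorder(); the hunk that follows records the row swaps during decode() instead of applying them immediately. For illustration only, here is a minimal standalone sketch of this deferred-permutation idea, using invented names (output_buffer, record_swap) rather than the real llama.cpp types:

#include <cstdint>
#include <cstdio>
#include <utility>
#include <vector>

// Sketch: "decode" only records which output rows are out of order; the swaps are
// applied once, lazily, when the rows are actually read (like get_logits()).
struct output_buffer {
    std::vector<float>                         rows;   // n_rows x n_cols, row-major
    size_t                                     n_cols;
    std::vector<std::pair<uint32_t, uint32_t>> swaps;  // pending row swaps

    void record_swap(uint32_t i0, uint32_t i1) { swaps.push_back({i0, i1}); }

    // counterpart of llama_context::output_reorder()
    void reorder() {
        for (const auto & s : swaps) {
            for (size_t k = 0; k < n_cols; ++k) {
                std::swap(rows[s.first * n_cols + k], rows[s.second * n_cols + k]);
            }
        }
        swaps.clear();
    }

    const float * data() { reorder(); return rows.data(); }  // like get_logits()
};

int main() {
    output_buffer buf { {0, 0, 1, 1, 2, 2}, 2, {} };
    buf.record_swap(0, 2);                   // "decode" noticed rows 0 and 2 are swapped
    const float * p = buf.data();            // the swap is applied here, on first access
    std::printf("%.0f %.0f\n", p[0], p[4]);  // prints: 2 0
    return 0;
}

In the sketch, as in the patch, nothing is moved until the data is actually requested.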
@@ -1207,16 +1213,9 @@ int llama_context::decode(const llama_batch & batch_inp) { continue; } std::swap(out_ids[i], out_ids[j_min]); - if (logits_size > 0) { - for (uint32_t k = 0; k < n_vocab; k++) { - std::swap(logits[i*n_vocab + k], logits[j_min*n_vocab + k]); - } - } - if (embd_size > 0) { - for (uint32_t k = 0; k < n_embd; k++) { - std::swap(embd[i*n_embd + k], embd[j_min*n_embd + k]); - } - } + + // remember the swaps and apply them lazily upon logits/embeddings access + output_swaps.push_back({ i, j_min }); } std::fill(output_ids.begin(), output_ids.end(), -1); @@ -1307,6 +1306,30 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { return n_outputs_max; } +void llama_context::output_reorder() { + const uint32_t n_vocab = model.vocab.n_tokens(); + const uint64_t n_embd = model.hparams.n_embd; + + for (uint32_t s = 0; s < output_swaps.size(); ++s) { + const uint32_t i0 = output_swaps[s].i0; + const uint32_t i1 = output_swaps[s].i1; + + if (logits_size > 0) { + for (uint32_t k = 0; k < n_vocab; k++) { + std::swap(logits[i0*n_vocab + k], logits[i1*n_vocab + k]); + } + } + + if (embd_size > 0) { + for (uint32_t k = 0; k < n_embd; k++) { + std::swap(embd[i0*n_embd + k], embd[i1*n_embd + k]); + } + } + } + + output_swaps.clear(); +} + // // graph // diff --git a/src/llama-context.h b/src/llama-context.h index 1601ac682..fdbe61207 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -181,6 +181,8 @@ private: // Returns max number of outputs for which space was reserved. uint32_t output_reserve(int32_t n_outputs); + void output_reorder(); + // // graph // @@ -250,6 +252,13 @@ private: std::vector<int32_t> output_ids; // map batch token positions to ids of the logits and embd buffers + struct swap_info { + uint32_t i0; + uint32_t i1; + }; + + std::vector<swap_info> output_swaps; + ggml_backend_sched_ptr sched; ggml_backend_t backend_cpu = nullptr; diff --git a/src/llama-memory-recurrent.cpp b/src/llama-memory-recurrent.cpp index 25d62bb9a..7305d99db 100644 --- a/src/llama-memory-recurrent.cpp +++ b/src/llama-memory-recurrent.cpp @@ -768,6 +768,8 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std:: // Iterate and write all the keys first, each row is a cell // Get whole range at a time for (uint32_t il = 0; il < n_layer; ++il) { + // skip null layers (read_data will handle this by checking "r_l" and "s_l" for null) + if (r_l[il] == nullptr) continue; // Write key type const int32_t r_type_i = (int32_t)r_l[il]->type; @@ -787,6 +789,8 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std:: if (!s_trans) { for (uint32_t il = 0; il < n_layer; ++il) { + // skip null layers (read_data will handle this by checking "r_l" and "s_l" for null) + if (s_l[il] == nullptr) continue; // Write value type const int32_t s_type_i = (int32_t)s_l[il]->type; @@ -807,6 +811,9 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std:: // When v is transposed, we also need the element size and get the element ranges from each row const uint32_t mem_size = size; for (uint32_t il = 0; il < n_layer; ++il) { + // skip null layers (read_data will handle this by checking "r_l" and "s_l" for null) + if (s_l[il] == nullptr) continue; + const uint32_t n_embd_s = hparams.n_embd_s(); // Write value type @@ -951,6 +958,8 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block for (uint32_t il = 0; il < n_layer; ++il) { + // 
skip null layers + if (r_l[il] == nullptr) continue; // Read type of key int32_t r_type_i_ref; @@ -978,11 +987,14 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell if (!s_trans) { for (uint32_t il = 0; il < n_layer; ++il) { + // skip null layers + if (s_l[il] == nullptr) continue; // Read type of value int32_t s_type_i_ref; io.read_to(&s_type_i_ref, sizeof(s_type_i_ref)); const int32_t s_type_i = (int32_t)s_l[il]->type; + if (s_type_i != s_type_i_ref) { LLAMA_LOG_ERROR("%s: mismatched s type (%d != %d, layer %d)\n", __func__, s_type_i, s_type_i_ref, il); return false; @@ -1005,6 +1017,9 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell } else { // For each layer, read the values for each cell (transposed) for (uint32_t il = 0; il < n_layer; ++il) { + // skip null layers + if (s_l[il] == nullptr) continue; + const uint32_t n_embd_s = hparams.n_embd_s(); // Read type of value diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 9a0839895..c98453520 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -651,6 +651,9 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale); ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); + // MiniCPM uses rope by default, unlike Granite which uses it as a switch + hparams.rope_finetuned = true; + switch (hparams.n_layer) { case 52: type = LLM_TYPE_1B; break; case 40: type = LLM_TYPE_2B; break; @@ -1549,7 +1552,11 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_TOKEN_SHIFT_COUNT, hparams.token_shift_count, false); switch (hparams.n_layer) { - case 12: type = LLM_TYPE_190M; break; + case 12: + switch (hparams.n_embd) { + case 768: type = LLM_TYPE_190M; break; + default: type = LLM_TYPE_UNKNOWN; + } break; case 24: switch (hparams.n_embd) { case 1024: type = LLM_TYPE_450M; break; @@ -1562,7 +1569,17 @@ void llama_model::load_hparams(llama_model_loader & ml) { case 3584: type = LLM_TYPE_7B; break; default: type = LLM_TYPE_UNKNOWN; } break; - case 32: type = LLM_TYPE_2_9B; break; // RWKV-7-World + case 32: + switch (hparams.n_embd) { + case 2560: type = LLM_TYPE_2_9B; break; + case 4096: type = LLM_TYPE_7B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 61: + switch (hparams.n_embd) { + case 4096: type = LLM_TYPE_14B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; default: type = LLM_TYPE_UNKNOWN; } } break;
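On the llama-memory-recurrent.cpp change: the important invariant is that state_write_data() and state_read_data() skip exactly the same null layers, otherwise every field that follows in the session stream shifts and the type checks start failing. A toy sketch of that invariant, simplified to a single int32 "type" field per layer (the helper names are illustrative, not the actual API):

#include <cstdint>
#include <vector>

// Each layer either has a tensor (r != nullptr) or is absent (e.g. non-recurrent
// layers in a hybrid model). Writer and reader must apply the same skip rule.
struct layer_state { const void * r; int32_t type; };

static void write_state(std::vector<int32_t> & stream, const std::vector<layer_state> & layers) {
    for (const auto & l : layers) {
        if (l.r == nullptr) continue;  // skip null layers, as in state_write_data()
        stream.push_back(l.type);      // per-layer header (just a type id in this toy)
    }
}

static bool read_state(const std::vector<int32_t> & stream, const std::vector<layer_state> & layers) {
    size_t pos = 0;
    for (const auto & l : layers) {
        if (l.r == nullptr) continue;  // same rule as the writer keeps the stream aligned
        if (pos >= stream.size() || stream[pos++] != l.type) {
            return false;              // analogous to the "mismatched type" error path
        }
    }
    return pos == stream.size();
}

int main() {
    static const int dummy = 0;
    const std::vector<layer_state> layers = { {&dummy, 7}, {nullptr, 0}, {&dummy, 9} };
    std::vector<int32_t> stream;
    write_state(stream, layers);
    return read_state(stream, layers) ? 0 : 1;  // 0 = round-trip ok
}

The patch applies the same continue rule in all of the write loops so that the existing read-side null checks stay in sync with what was actually written.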
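On the ggml-metal.m change: when the graph is encoded by several command encoders in parallel, a fused run that starts at idx must not reach into the next encoder's range, which is what the added "idx + n_fuse + 1 < idx_end" condition and the GGML_ABORT guard enforce. A self-contained sketch of the same bounded fusion loop, with toy op codes and invented helper names (not the real fusion logic):

#include <cstdio>
#include <vector>

// Fuse runs of equal toy op codes, but never across idx_end (the first node that
// belongs to the next encoder/split).
static int encode_node(const std::vector<int> & ops, int idx, int idx_end) {
    int n_fuse = 1;
    // idx + n_fuse must stay below idx_end so the fused run never crosses the split
    while (n_fuse < 6 && idx + n_fuse < idx_end && ops[idx + n_fuse] == ops[idx]) {
        n_fuse++;
    }
    return n_fuse;  // number of nodes consumed, like the 'res' checked against node_end
}

int main() {
    const std::vector<int> ops    = {1, 1, 1, 2, 2, 3};
    const std::vector<int> splits = {0, 2, 6};  // encoder 0 owns [0,2), encoder 1 owns [2,6)
    for (size_t s = 0; s + 1 < splits.size(); ++s) {
        for (int idx = splits[s]; idx < splits[s + 1]; ) {
            const int res = encode_node(ops, idx, splits[s + 1]);
            std::printf("encoder %zu: fused nodes [%d, %d)\n", s, idx, idx + res);
            idx += res;  // idx + res never exceeds splits[s + 1]
        }
    }
    return 0;
}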
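On the Vulkan changes: ggml_vk_can_fuse() now compares the full shapes with ggml_are_same_shape() instead of only ne[1] when the rms_norm output is the B operand, and rms_norm.comp gains a broadcast path where the multiplier is indexed modulo its own column count (fastmod(col, p.ne10) in the shader). In scalar C++ the fused rms_norm + mul with that broadcast looks roughly like this (a sketch, not the shader):

#include <cmath>
#include <cstdio>
#include <vector>

// Scalar model of the fused rms_norm + mul path: when the multiplier 'b' has fewer
// columns than 'a' (ncols > ne10), its column index wraps around, which is what the
// shader's fastmod(col, p.ne10) computes.
int main() {
    const std::vector<float> a = {1.0f, 2.0f, 3.0f, 4.0f};  // one row, ncols = 4
    const std::vector<float> b = {0.5f, 2.0f};              // ne10 = 2, broadcast over a
    const float eps = 1e-6f;

    float sum = 0.0f;
    for (float x : a) {
        sum += x * x;
    }
    const float scale = 1.0f / std::sqrt(sum / a.size() + eps);  // inversesqrt(mean + eps)

    for (size_t col = 0; col < a.size(); ++col) {
        const float d = scale * a[col] * b[col % b.size()];  // broadcast via modulo
        std::printf("%zu: %f\n", col, d);
    }
    return 0;
}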