Merge branch 'upstream' into concedo_experimental

# Conflicts:
#	.devops/musa.Dockerfile
#	.github/workflows/build.yml
#	.github/workflows/close-issue.yml
#	ci/README.md
#	docs/build.md
#	docs/docker.md
#	ggml/CMakeLists.txt
#	ggml/cmake/ggml-config.cmake.in
#	ggml/src/ggml-cann/aclnn_ops.cpp
#	ggml/src/ggml-cann/aclnn_ops.h
#	ggml/src/ggml-cann/ggml-cann.cpp
#	ggml/src/ggml-cpu/CMakeLists.txt
#	ggml/src/ggml-cuda/fattn-wmma-f16.cu
#	ggml/src/ggml-musa/CMakeLists.txt
#	ggml/src/ggml-rpc/ggml-rpc.cpp
#	ggml/src/ggml-sycl/ggml-sycl.cpp
#	ggml/src/ggml-sycl/vecdotq.hpp
#	scripts/sync-ggml.last
#	tests/test-backend-ops.cpp
#	tools/imatrix/README.md
#	tools/imatrix/imatrix.cpp
Concedo 2025-07-25 19:53:13 +08:00
commit 0fcfbdb93c
33 changed files with 501 additions and 348 deletions


@@ -2657,6 +2657,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
            params.i_chunk = value;
        }
    ).set_examples({LLAMA_EXAMPLE_IMATRIX}));
+    add_opt(common_arg(
+        {"--show-statistics"},
+        string_format("show imatrix statistics and then exit (default: %s)", params.show_statistics ? "true" : "false"),
+        [](common_params & params) {
+            params.show_statistics = true;
+        }
+    ).set_examples({LLAMA_EXAMPLE_IMATRIX}));
    add_opt(common_arg(
        {"--parse-special"},
        string_format("prase special tokens (chat, tool, etc) (default: %s)", params.parse_special ? "true" : "false"),


@@ -430,6 +430,7 @@ struct common_params
     bool process_output = false; // collect data for the output tensor
     bool compute_ppl = true; // whether to compute perplexity
+    bool show_statistics = false; // show imatrix statistics per tensor
     bool parse_special = false; // whether to parse special tokens during imatrix tokenization
 
     // cvector-generator params


@@ -6486,7 +6486,7 @@ class JaisModel(TextModel):
         self.gguf_writer.add_max_alibi_bias(self.max_alibi_bias)
 
 
-@ModelBase.register("Glm4ForCausalLM")
+@ModelBase.register("Glm4ForCausalLM", "Glm4vForConditionalGeneration")
 class Glm4Model(TextModel):
     model_arch = gguf.MODEL_ARCH.GLM4
 
@@ -6508,7 +6508,8 @@ class Glm4Model(TextModel):
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
-        rope_dim = self.hparams["head_dim"]
+        if (rope_dim := self.hparams.get("head_dim")) is None:
+            rope_dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
         self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.hparams.get("partial_rotary_factor", 0.5)))
         rope_scaling = self.hparams.get("rope_scaling") or {}
         if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling:
@@ -6516,6 +6517,13 @@ class Glm4Model(TextModel):
             self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
             self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        if name.startswith("model.visual."):  # ignore visual part of Glm4v
+            return []
+        elif name.startswith("model.language_model."):
+            name = name.replace("language_model.", "")  # for Glm4v
+        return super().modify_tensors(data_torch, name, bid)
 
 
 @ModelBase.register("GlmForCausalLM", "ChatGLMModel", "ChatGLMForConditionalGeneration")
 class ChatGLMModel(TextModel):


@@ -647,6 +647,7 @@ struct ggml_backend_sched {
     // pipeline parallelism support
     int n_copies;
     int cur_copy;
+    int next_copy;
     ggml_backend_event_t events[GGML_SCHED_MAX_BACKENDS][GGML_SCHED_MAX_COPIES];
     struct ggml_tensor * graph_inputs[GGML_SCHED_MAX_SPLIT_INPUTS];
     int n_graph_inputs;
@@ -1439,8 +1440,6 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
         }
     }
 
-    sched->cur_copy = (sched->cur_copy + 1) % sched->n_copies;
-
     return GGML_STATUS_SUCCESS;
 }
@@ -1541,10 +1540,10 @@ void ggml_backend_sched_reset(ggml_backend_sched_t sched) {
 bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph) {
     GGML_ASSERT((int)sched->hash_set.size >= measure_graph->n_nodes + measure_graph->n_leafs);
 
-    ggml_backend_sched_split_graph(sched, measure_graph);
-
     ggml_backend_sched_synchronize(sched);
 
+    ggml_backend_sched_split_graph(sched, measure_graph);
+
     if (!ggml_gallocr_reserve_n(sched->galloc, &sched->graph, sched->node_backend_ids, sched->leaf_backend_ids)) {
         return false;
     }
@@ -1556,6 +1555,10 @@ bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph *
 bool ggml_backend_sched_alloc_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
     GGML_ASSERT((int)sched->hash_set.size >= graph->n_nodes + graph->n_leafs);
+    GGML_ASSERT(!sched->is_alloc);
+
+    sched->cur_copy = sched->next_copy;
+    sched->next_copy = (sched->next_copy + 1) % sched->n_copies;
 
     ggml_backend_sched_split_graph(sched, graph);
@@ -1596,7 +1599,7 @@ void ggml_backend_sched_synchronize(ggml_backend_sched_t sched) {
         // if the graph is not already allocated, always use copy 0 after a synchronization
         // this ensures that during generation the same copy is used every time,
        // which avoids changes in the graph that could cause CUDA or other graphs to be disabled
-        sched->cur_copy = 0;
+        sched->next_copy = 0;
    }
 }
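Editor's note, not part of the commit: a minimal C++ sketch of the copy-rotation semantics this file's hunks introduce. The copy index now advances when a graph is allocated (via `next_copy`) rather than after it is computed, and a synchronize steers only the next allocation back to copy 0. Names mirror the diff; the real scheduler additionally guards the reset behind its "graph not already allocated" condition.

```cpp
// Hedged sketch of the cur_copy/next_copy rotation; not the real scheduler.
#include <cstdio>

struct sched_sketch {
    int n_copies  = 2;
    int cur_copy  = 0; // copy used by the currently allocated graph
    int next_copy = 0; // copy the next alloc_graph() call will take
};

void alloc_graph(sched_sketch & s) {   // mirrors ggml_backend_sched_alloc_graph
    s.cur_copy  = s.next_copy;
    s.next_copy = (s.next_copy + 1) % s.n_copies;
}

void synchronize(sched_sketch & s) {   // mirrors ggml_backend_sched_synchronize
    s.next_copy = 0;                   // real code: only when no graph is allocated
}

int main() {
    sched_sketch s;
    alloc_graph(s);  // cur=0, next=1
    alloc_graph(s);  // cur=1, next=0
    synchronize(s);  // next allocation falls back to copy 0
    alloc_graph(s);  // cur=0, next=1
    std::printf("cur=%d next=%d\n", s.cur_copy, s.next_copy);
    return 0;
}
```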


@@ -544,7 +544,7 @@ void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, i
        __m128 max4 = __lsx_vfmax_s( lasx_extractf128( max_abs, 1 ), lasx_extractf128( max_abs, 0) );
        max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vpickod_d((__m128i) max4, (__m128i)max4 ) );
        __m128 tmp = max4;
-        max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vextrins_w((__m128i)tmp, (__m128i)max4, 0x10 ));
+        max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vextrins_w((__m128i)tmp, (__m128i)max4, 0x1 ));
        const float max_scalar = ((v4f32)max4)[0];
 
        // Quantize these floats


@@ -14,7 +14,6 @@
 #include <cmath>
 #include <cstring>
 #include <cassert>
-#include <cstdlib> // for qsort
 #include <cstdio>  // for GGML_ASSERT
 
 #include "repack.h"


@@ -769,7 +769,7 @@ struct ggml_tensor_extra_gpu {
 };
 
-#if (defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS))
+#if (defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS)) || defined(GGML_MUSA_GRAPHS)
 #define USE_CUDA_GRAPH
 #endif


@@ -6,24 +6,33 @@
 #define CUDA_Q8_0_NE_ALIGN 2048
 
 template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
-static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k) {
-    const int64_t i = (int64_t)2*(blockDim.x*blockIdx.x + threadIdx.x);
+static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y,
+        const int64_t ne00, const int64_t ne01, const int64_t ne02,
+        const int64_t s01, const int64_t s02, const int64_t s03) {
+    const int64_t i00 = 2 * (int64_t(blockDim.x)*blockIdx.x + threadIdx.x);
 
-    if (i >= k) {
+    if (i00 >= ne00) {
         return;
     }
 
-    const int64_t ib = i/qk; // block index
-    const int64_t iqs = (i%qk)/qr; // quant index
-    const int64_t iybs = i - i%qk; // y block start index
+    const int64_t i01 = blockIdx.y;
+    const int64_t i02 = blockIdx.z % ne02;
+    const int64_t i03 = blockIdx.z / ne02;
+
+    const int64_t ibx0 = i03*s03 + i02*s02 + i01*s01;
+
+    const int64_t ib = ibx0 + i00/qk; // block index
+    const int64_t iqs = (i00%qk)/qr; // quant index
+    const int64_t iybs = i00 - i00%qk; // y block start index
     const int64_t y_offset = qr == 1 ? 1 : qk/2;
 
     // dequantize
     dfloat2 v;
     dequantize_kernel(vx, ib, iqs, v);
 
-    y[iybs + iqs + 0]        = v.x;
-    y[iybs + iqs + y_offset] = v.y;
+    const int64_t iy0 = ((i03*ne02 + i02)*ne01 + i01)*ne00 + iybs + iqs;
+    y[iy0 + 0]        = float(v.x);
+    y[iy0 + y_offset] = float(v.y);
 }
 
 template <bool need_check>
@@ -457,9 +466,17 @@ static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst
 }
 
 template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
-static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) {
-    const int num_blocks = (k + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE);
-    dequantize_block<qk, qr, dequantize_kernel><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
+static void dequantize_block_cuda(const void * vx, dst_t * y,
+        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
+        const int64_t s01, const int64_t s02, const int64_t s03, cudaStream_t stream) {
+    const dim3 num_blocks((ne00 + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE), ne01, ne02*ne03);
+    dequantize_block<qk, qr, dequantize_kernel><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>
+        (vx, y, ne00, ne01, ne02, s01, s02, s03);
+}
+
+template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
+static void dequantize_block_cont_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) {
+    dequantize_block_cuda<qk, qr, dequantize_kernel, dst_t>(vx, y, k, 1, 1, 1, k/qk, k/qk, k/qk, stream);
 }
 
 static void dequantize_block_q8_0_f16_cuda(const void * __restrict__ vx, half * __restrict__ y, const int64_t k, cudaStream_t stream) {
@@ -624,14 +641,14 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
        case GGML_TYPE_Q4_1:
            return dequantize_row_q4_1_cuda;
        case GGML_TYPE_Q5_0:
-            return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
+            return dequantize_block_cont_cuda<QK5_0, QR5_0, dequantize_q5_0>;
        case GGML_TYPE_Q5_1:
-            return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
+            return dequantize_block_cont_cuda<QK5_1, QR5_1, dequantize_q5_1>;
        case GGML_TYPE_Q8_0:
            if (fp16_available(ggml_cuda_info().devices[ggml_cuda_get_device()].cc)) {
                return dequantize_block_q8_0_f16_cuda;
            }
-            return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
+            return dequantize_block_cont_cuda<QK8_0, QR8_0, dequantize_q8_0>;
        case GGML_TYPE_Q2_K:
            return dequantize_row_q2_K_cuda;
        case GGML_TYPE_Q3_K:
@@ -676,11 +693,11 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
        case GGML_TYPE_Q4_1:
            return dequantize_row_q4_1_cuda;
        case GGML_TYPE_Q5_0:
-            return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
+            return dequantize_block_cont_cuda<QK5_0, QR5_0, dequantize_q5_0>;
        case GGML_TYPE_Q5_1:
-            return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
+            return dequantize_block_cont_cuda<QK5_1, QR5_1, dequantize_q5_1>;
        case GGML_TYPE_Q8_0:
-            return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
+            return dequantize_block_cont_cuda<QK8_0, QR8_0, dequantize_q8_0>;
        case GGML_TYPE_Q2_K:
            return dequantize_row_q2_K_cuda;
        case GGML_TYPE_Q3_K:
@@ -722,6 +739,16 @@ to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type) {
    switch (type) {
        case GGML_TYPE_F32:
            return convert_unary_cuda<float>;
+        case GGML_TYPE_Q4_0:
+            return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
+        case GGML_TYPE_Q4_1:
+            return dequantize_block_cuda<QK4_1, QR4_1, dequantize_q4_1>;
+        case GGML_TYPE_Q5_0:
+            return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
+        case GGML_TYPE_Q5_1:
+            return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
+        case GGML_TYPE_Q8_0:
+            return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
        case GGML_TYPE_BF16:
            return convert_unary_cuda<nv_bfloat16>;
        default:
@@ -733,6 +760,16 @@ to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type) {
    switch (type) {
        case GGML_TYPE_F32:
            return convert_unary_cuda<float, nv_bfloat16>;
+        case GGML_TYPE_Q4_0:
+            return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
+        case GGML_TYPE_Q4_1:
+            return dequantize_block_cuda<QK4_1, QR4_1, dequantize_q4_1>;
+        case GGML_TYPE_Q5_0:
+            return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
+        case GGML_TYPE_Q5_1:
+            return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
+        case GGML_TYPE_Q8_0:
+            return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
        case GGML_TYPE_F16:
            return convert_unary_cuda<half, nv_bfloat16>;
        default:
@@ -744,6 +781,16 @@ to_fp32_nc_cuda_t ggml_get_to_fp32_nc_cuda(ggml_type type) {
    switch (type) {
        case GGML_TYPE_F16:
            return convert_unary_cuda<half, float>;
+        case GGML_TYPE_Q4_0:
+            return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
+        case GGML_TYPE_Q4_1:
+            return dequantize_block_cuda<QK4_1, QR4_1, dequantize_q4_1>;
+        case GGML_TYPE_Q5_0:
+            return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
+        case GGML_TYPE_Q5_1:
+            return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
+        case GGML_TYPE_Q8_0:
+            return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
        case GGML_TYPE_BF16:
            return convert_unary_cuda<nv_bfloat16, float>;
        default:
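Editor's note, not part of the commit: a small host-side sketch of the index arithmetic the generalized `dequantize_block` kernel above uses, and of how `dequantize_block_cont_cuda` collapses a contiguous tensor onto it by passing `ne01 = ne02 = ne03 = 1` and `s01 = s02 = s03 = k/qk`.

```cpp
// Illustrative only: reproduces the kernel's source-block indexing on the host.
#include <cassert>
#include <cstdint>

// Block index of element (i00, i01, i02, i03); s01..s03 are strides in quant blocks.
int64_t src_block_index(int64_t i00, int64_t i01, int64_t i02, int64_t i03,
                        int64_t s01, int64_t s02, int64_t s03, int64_t qk) {
    const int64_t ibx0 = i03*s03 + i02*s02 + i01*s01;
    return ibx0 + i00/qk;                  // same formula as 'ib' in the kernel
}

// Offset of the element's block start in the densely packed destination buffer.
int64_t dst_block_start(int64_t i00, int64_t i01, int64_t i02, int64_t i03,
                        int64_t ne00, int64_t ne01, int64_t ne02, int64_t qk) {
    const int64_t iybs = i00 - i00%qk;     // y block start index, as in the kernel
    return ((i03*ne02 + i02)*ne01 + i01)*ne00 + iybs;
}

int main() {
    const int64_t qk = 32;                 // e.g. QK8_0
    const int64_t k  = 4096;               // elements of a contiguous 1D tensor

    // The contiguous wrapper forwards (ne00=k, ne01=ne02=ne03=1, s01=s02=s03=k/qk),
    // so the generalized indexing degenerates to the old contiguous behaviour:
    assert(src_block_index(1000, 0, 0, 0, k/qk, k/qk, k/qk, qk) == 1000/qk);
    assert(dst_block_start(1000, 0, 0, 0, k, 1, 1, qk) == 1000 - 1000%qk);
    return 0;
}
```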


@@ -1,9 +1,9 @@
 #include "cpy.cuh"
 #include "dequantize.cuh"
 #include "cpy-utils.cuh"
-#ifdef GGML_USE_MUSA
+#if defined(GGML_USE_MUSA) && defined(GGML_MUSA_MUDNN_COPY)
 #include "ggml-musa/mudnn.cuh"
-#endif // GGML_USE_MUSA
+#endif // GGML_USE_MUSA && GGML_MUSA_MUDNN_COPY
 
 typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
@@ -121,7 +121,7 @@ static __global__ void cpy_q_f32(const char * cx, char * cdst_direct, const int
 
 // Copy destination pointers to GPU to be available when pointer indirection is in use
 void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_dest_ptrs, const int host_dest_ptrs_size, cudaStream_t stream) {
-#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS)
+#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS)
    if (cuda_graph->dest_ptrs_size < host_dest_ptrs_size) { // (re-)allocate GPU memory for destination pointers
        CUDA_CHECK(cudaStreamSynchronize(stream));
        if (cuda_graph->dest_ptrs_d != nullptr) {
@@ -314,7 +314,7 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
 
    char ** dest_ptrs_d = nullptr;
    int graph_cpynode_index = -1;
-#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS)
+#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS)
    if(ctx.cuda_graph->use_cpy_indirection && !disable_indirection_for_this_node) {
        dest_ptrs_d = ctx.cuda_graph->dest_ptrs_d;
        graph_cpynode_index = ctx.cuda_graph->graph_cpynode_index;
@@ -324,11 +324,11 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
 #endif
    if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
        GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));
-#ifdef GGML_USE_MUSA
+#if defined(GGML_USE_MUSA) && defined(GGML_MUSA_MUDNN_COPY)
        if (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16) {
            CUDA_CHECK(mudnnMemcpyAsync(ctx, src1, src0));
        } else
-#endif // GGML_USE_MUSA
+#endif // GGML_USE_MUSA && GGML_MUSA_MUDNN_COPY
        {
            CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
        }
@@ -379,7 +379,7 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
        GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
                   ggml_type_name(src0->type), ggml_type_name(src1->type));
    }
-#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS)
+#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS)
    if(ctx.cuda_graph->use_cpy_indirection && !disable_indirection_for_this_node) {
        ctx.cuda_graph->graph_cpynode_index = graph_cpynode_index;
    }


@@ -23,33 +23,13 @@ typedef void (* fattn_kernel_t)(
        const float m1,
        const uint32_t n_head_log2,
        const float logit_softcap,
-        const int ne00,
-        const int ne01,
-        const int ne02,
-        const int ne03,
-        const int ne10,
-        const int ne11,
-        const int ne12,
-        const int ne13,
-        const int ne31,
-        const int ne32,
-        const int ne33,
-        const int nb31,
-        const int nb32,
-        const int nb33,
-        const int nb01,
-        const int nb02,
-        const int nb03,
-        const int nb11,
-        const int nb12,
-        const int nb13,
-        const int nb21,
-        const int nb22,
-        const int nb23,
-        const int ne0,
-        const int ne1,
-        const int ne2,
-        const int ne3);
+        const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
+        const int32_t nb01, const int32_t nb02, const int32_t nb03,
+        const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
+        const int32_t nb11, const int32_t nb12, const int64_t nb13,
+        const int32_t nb21, const int32_t nb22, const int64_t nb23,
+        const int32_t ne31, const int32_t ne32, const int32_t ne33,
+        const int32_t nb31, const int32_t nb32, const int64_t nb33);
 
 typedef half (*vec_dot_KQ_f16_t)(
    const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds);
@@ -745,33 +725,58 @@ void launch_fattn(
    size_t nb23 = V ? V->nb[3] : nb13;
 
    if (need_f16_K && K->type != GGML_TYPE_F16) {
-        GGML_ASSERT(ggml_is_contiguously_allocated(K));
-        K_f16.alloc(ggml_nelements(K));
-        to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type);
-        to_fp16(K_data, K_f16.ptr, ggml_nelements(K), main_stream);
-        K_data = (char *) K_f16.ptr;
-
        const size_t bs = ggml_blck_size(K->type);
        const size_t ts = ggml_type_size(K->type);
 
+        K_f16.alloc(ggml_nelements(K));
+        if (ggml_is_contiguously_allocated(K)) {
+            to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type);
+            to_fp16(K_data, K_f16.ptr, ggml_nelements(K), main_stream);
+
            nb11 = nb11*bs*sizeof(half)/ts;
            nb12 = nb12*bs*sizeof(half)/ts;
            nb13 = nb13*bs*sizeof(half)/ts;
+        } else {
+            GGML_ASSERT(K->nb[0] == ts);
+            to_fp16_nc_cuda_t to_fp16 = ggml_get_to_fp16_nc_cuda(K->type);
+            const int64_t s01 = nb11 / ts;
+            const int64_t s02 = nb12 / ts;
+            const int64_t s03 = nb13 / ts;
+            to_fp16(K_data, K_f16.ptr, K->ne[0], K->ne[1], K->ne[2], K->ne[3], s01, s02, s03, main_stream);
+
+            nb11 = K->ne[0] * sizeof(half);
+            nb12 = K->ne[1] * nb11;
+            nb13 = K->ne[2] * nb12;
+        }
+        K_data = (char *) K_f16.ptr;
    }
 
    if (V && need_f16_V && V->type != GGML_TYPE_F16) {
-        GGML_ASSERT(ggml_is_contiguously_allocated(V));
+        const size_t bs = ggml_blck_size(V->type);
+        const size_t ts = ggml_type_size(V->type);
+
        V_f16.alloc(ggml_nelements(V));
+        if (ggml_is_contiguously_allocated(V)) {
            to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
            to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream);
            V_data = (char *) V_f16.ptr;
 
-        const size_t bs = ggml_blck_size(V->type);
-        const size_t ts = ggml_type_size(V->type);
-
            nb21 = nb21*bs*sizeof(half)/ts;
            nb22 = nb22*bs*sizeof(half)/ts;
            nb23 = nb23*bs*sizeof(half)/ts;
+        } else {
+            GGML_ASSERT(V->nb[0] == ts);
+            to_fp16_nc_cuda_t to_fp16 = ggml_get_to_fp16_nc_cuda(V->type);
+            const int64_t s01 = nb21 / ts;
+            const int64_t s02 = nb22 / ts;
+            const int64_t s03 = nb23 / ts;
+            to_fp16(V_data, V_f16.ptr, V->ne[0], V->ne[1], V->ne[2], V->ne[3], s01, s02, s03, main_stream);
+
+            nb21 = V->ne[0] * sizeof(half);
+            nb22 = V->ne[1] * nb21;
+            nb23 = V->ne[2] * nb22;
+        }
+        V_data = (char *) V_f16.ptr;
    }
 
    int parallel_blocks = 1;
@@ -867,14 +872,11 @@ void launch_fattn(
        mask ? ((const char *) mask->data) : nullptr,
        !stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
        scale, max_bias, m0, m1, n_head_log2, logit_softcap,
-        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
-        K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-        mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0,
-        mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0, mask ? mask->nb[3] : 0,
-        Q->nb[1], Q->nb[2], Q->nb[3],
-        nb11, nb12, nb13,
+        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], Q->nb[1], Q->nb[2], Q->nb[3],
+        K->ne[0], K->ne[1], K->ne[2], K->ne[3], nb11, nb12, nb13,
        nb21, nb22, nb23,
-        KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
+        mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0,
+        mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0, mask ? mask->nb[3] : 0
    );
 
    CUDA_CHECK(cudaGetLastError());
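Editor's note, an assumption-laden sketch rather than code from the patch: in the non-contiguous branch above, the quantized K/V byte strides are divided by `ggml_type_size()` to obtain strides measured in quant blocks for the `*_nc` converter, and the freshly packed F16 copy then gets plain row-major byte strides. Roughly:

```cpp
// Sketch of the stride bookkeeping in launch_fattn's non-contiguous K/V path.
#include <cstdint>

struct byte_strides { int64_t nb1, nb2, nb3; };

// Strides handed to the non-contiguous converter, in units of quant blocks.
// Valid because the code first asserts that nb[0] equals the type size 'ts'.
byte_strides to_block_strides(const byte_strides & src, int64_t ts) {
    return { src.nb1 / ts, src.nb2 / ts, src.nb3 / ts };
}

// Byte strides of the densely packed F16 buffer produced by the conversion,
// mirroring nb11 = ne[0]*sizeof(half); nb12 = ne[1]*nb11; nb13 = ne[2]*nb12;
byte_strides packed_f16_strides(const int64_t ne[4]) {
    byte_strides out;
    out.nb1 = ne[0] * (int64_t) sizeof(uint16_t); // stand-in for sizeof(half)
    out.nb2 = ne[1] * out.nb1;
    out.nb3 = ne[2] * out.nb2;
    return out;
}

int main() {
    const int64_t ne[4] = {128, 4096, 8, 1};            // hypothetical K shape
    const byte_strides q = {(128/32)*34, 4096*(128/32)*34, 8*4096*(128/32)*34};
    (void) to_block_strides(q, 34);                     // 34 bytes per Q8_0 block
    (void) packed_f16_strides(ne);
    return 0;
}
```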


@@ -408,7 +408,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
        const int stride_K,
        const int stride_V,
        const int stride_mask,
-        const int jt,
        half2 * const __restrict__ tile_Q,
        half2 * const __restrict__ tile_K,
        half2 * const __restrict__ tile_V,
@@ -455,7 +454,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
            cp_async_wait_all();
            __syncthreads();
            flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
-                (V_h2 + k_VKQ_0*stride_V, tile_V, nbatch_V2, stride_V);
+                (V_h2 + int64_t(k_VKQ_0)*stride_V, tile_V, nbatch_V2, stride_V);
        } else {
            constexpr bool use_cp_async = nstages == 1;
            if (ncols2 > 1 || mask_h2) {
@@ -471,7 +470,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
            if (nstages <= 1) {
                constexpr bool use_cp_async = nstages == 1;
                flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
-                    (K_h2 + k_VKQ_0*stride_K + k0_start, tile_K, k0_diff, stride_K);
+                    (K_h2 + int64_t(k_VKQ_0)*stride_K + k0_start, tile_K, k0_diff, stride_K);
                if (use_cp_async) {
                    cp_async_wait_all();
                }
@@ -715,7 +714,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
                    (mask_h2 + (k_VKQ_0 + c::nbatch_fa)/2, tile_mask, stride_mask);
            }
            flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
-                (K_h2 + (k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K);
+                (K_h2 + int64_t(k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K);
        }
    }
@@ -732,7 +731,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
        if (nstages <= 1 && i0_start < reusable_cutoff) {
            constexpr bool use_cp_async = nstages == 1;
            flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
-                (V_h2 + k_VKQ_0*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V);
+                (V_h2 + int64_t(k_VKQ_0)*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V);
            if (use_cp_async) {
                cp_async_wait_all();
            }
@@ -771,8 +770,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
    GGML_UNUSED(mask_h2); GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup);
    GGML_UNUSED(scale); GGML_UNUSED(slope); GGML_UNUSED(logit_softcap);
    GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_K); GGML_UNUSED(stride_V);
-    GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K);
-    GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K);
+    GGML_UNUSED(stride_mask); GGML_UNUSED(tile_K);
    GGML_UNUSED(tile_V); GGML_UNUSED(tile_mask); GGML_UNUSED(Q_B);
    GGML_UNUSED(VKQ_C); GGML_UNUSED(KQ_max); GGML_UNUSED(KQ_rowsum);
    GGML_UNUSED(kb0); GGML_UNUSED(tile_Q);
@@ -920,7 +918,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
                (mask_h2 + kb0_start*c::nbatch_fa/2, tile_mask, stride_mask);
        }
        flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
-            (K_h2 + kb0_start*c::nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K);
+            (K_h2 + int64_t(kb0_start)*c::nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K);
    }
 
    // Iterate over ne11 == previous tokens:
@@ -928,13 +926,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
        constexpr bool last_iter = false;
        flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
            (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
-             ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
+             ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
    }
    { // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
        constexpr bool last_iter = true;
        flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
            (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
-             ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1);
+             ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1);
    }
 
    // With multi-stage loading there is no __syncthreads at the end of the iter,
@@ -1214,33 +1212,13 @@ static __global__ void flash_attn_ext_f16(
        const float m1,
        const uint32_t n_head_log2,
        const float logit_softcap,
-        const int ne00,
-        const int ne01,
-        const int ne02,
-        const int ne03,
-        const int ne10,
-        const int ne11,
-        const int ne12,
-        const int ne13,
-        const int ne31,
-        const int ne32,
-        const int ne33,
-        const int nb31,
-        const int nb32,
-        const int nb33,
-        const int nb01,
-        const int nb02,
-        const int nb03,
-        const int nb11,
-        const int nb12,
-        const int nb13,
-        const int nb21,
-        const int nb22,
-        const int nb23,
-        const int ne0,
-        const int ne1,
-        const int ne2,
-        const int ne3) {
+        const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
+        const int32_t nb01, const int32_t nb02, const int32_t nb03,
+        const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
+        const int32_t nb11, const int32_t nb12, const int64_t nb13,
+        const int32_t nb21, const int32_t nb22, const int64_t nb23,
+        const int32_t ne31, const int32_t ne32, const int32_t ne33,
+        const int32_t nb31, const int32_t nb32, const int64_t nb33) {
 #if defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
 
    // Skip unused kernel variants for faster compilation:
@@ -1359,8 +1337,7 @@ static __global__ void flash_attn_ext_f16(
    GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
    GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
    GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21);
-    GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
-    GGML_UNUSED(ne2); GGML_UNUSED(ne3);
+    GGML_UNUSED(nb22); GGML_UNUSED(nb23);
    NO_DEVICE_CODE;
 #endif // defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
 }


@@ -21,33 +21,13 @@ static __global__ void flash_attn_tile_ext_f16(
        const float m1,
        const uint32_t n_head_log2,
        const float logit_softcap,
-        const int ne00,
-        const int ne01,
-        const int ne02,
-        const int ne03,
-        const int ne10,
-        const int ne11,
-        const int ne12,
-        const int ne13,
-        const int ne31,
-        const int ne32,
-        const int ne33,
-        const int nb31,
-        const int nb32,
-        const int nb33,
-        const int nb01,
-        const int nb02,
-        const int nb03,
-        const int nb11,
-        const int nb12,
-        const int nb13,
-        const int nb21,
-        const int nb22,
-        const int nb23,
-        const int ne0,
-        const int ne1,
-        const int ne2,
-        const int ne3) {
+        const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
+        const int32_t nb01, const int32_t nb02, const int32_t nb03,
+        const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
+        const int32_t nb11, const int32_t nb12, const int64_t nb13,
+        const int32_t nb21, const int32_t nb22, const int64_t nb23,
+        const int32_t ne31, const int32_t ne32, const int32_t ne33,
+        const int32_t nb31, const int32_t nb32, const int64_t nb33) {
 #if defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
 
    // Skip unused kernel variants for faster compilation:
@@ -127,7 +107,7 @@ static __global__ void flash_attn_tile_ext_f16(
            for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
                const int k_KQ = k_KQ_0 + threadIdx.x;
 
-                KV_tmp[i_KQ][k_KQ] = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ];
+                KV_tmp[i_KQ][k_KQ] = K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ];
            }
        }
@@ -221,7 +201,7 @@ static __global__ void flash_attn_tile_ext_f16(
            for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
                const int i = i0 + threadIdx.x;
 
-                KV_tmp[k][i] = V_h2[(k_VKQ_0 + k)*stride_KV2 + i];
+                KV_tmp[k][i] = V_h2[int64_t(k_VKQ_0 + k)*stride_KV2 + i];
            }
        }
@@ -300,8 +280,7 @@ static __global__ void flash_attn_tile_ext_f16(
    GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
    GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
    GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
-    GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
-    GGML_UNUSED(ne2); GGML_UNUSED(ne3);
+    GGML_UNUSED(nb23);
    NO_DEVICE_CODE;
 #endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
 }


@@ -21,33 +21,13 @@ static __global__ void flash_attn_tile_ext_f32(
        const float m1,
        const uint32_t n_head_log2,
        const float logit_softcap,
-        const int ne00,
-        const int ne01,
-        const int ne02,
-        const int ne03,
-        const int ne10,
-        const int ne11,
-        const int ne12,
-        const int ne13,
-        const int ne31,
-        const int ne32,
-        const int ne33,
-        const int nb31,
-        const int nb32,
-        const int nb33,
-        const int nb01,
-        const int nb02,
-        const int nb03,
-        const int nb11,
-        const int nb12,
-        const int nb13,
-        const int nb21,
-        const int nb22,
-        const int nb23,
-        const int ne0,
-        const int ne1,
-        const int ne2,
-        const int ne3) {
+        const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
+        const int32_t nb01, const int32_t nb02, const int32_t nb03,
+        const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
+        const int32_t nb11, const int32_t nb12, const int64_t nb13,
+        const int32_t nb21, const int32_t nb22, const int64_t nb23,
+        const int32_t ne31, const int32_t ne32, const int32_t ne33,
+        const int32_t nb31, const int32_t nb32, const int64_t nb33) {
 #ifdef FLASH_ATTN_AVAILABLE
 
    // Skip unused kernel variants for faster compilation:
@@ -66,8 +46,7 @@ static __global__ void flash_attn_tile_ext_f32(
        GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
        GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
        GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
-        GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
-        GGML_UNUSED(ne2); GGML_UNUSED(ne3);
+        GGML_UNUSED(nb23);
        NO_DEVICE_CODE;
        return;
    }
@@ -135,7 +114,7 @@ static __global__ void flash_attn_tile_ext_f32(
 #pragma unroll
            for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 2*WARP_SIZE) {
-                const half2 tmp = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + threadIdx.x];
+                const half2 tmp = K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + threadIdx.x];
                KV_tmp[i_KQ][k_KQ_0 + 0*WARP_SIZE + threadIdx.x] = __low2float(tmp);
                KV_tmp[i_KQ][k_KQ_0 + 1*WARP_SIZE + threadIdx.x] = __high2float(tmp);
            }
@@ -231,8 +210,9 @@ static __global__ void flash_attn_tile_ext_f32(
            for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
                const int i = i0 + threadIdx.x;
 
-                KV_tmp2[k*(D/2) + i].x = __low2float(V_h2[(k_VKQ_0 + k)*stride_KV2 + i]);
-                KV_tmp2[k*(D/2) + i].y = __high2float(V_h2[(k_VKQ_0 + k)*stride_KV2 + i]);
+                const half2 tmp = V_h2[int64_t(k_VKQ_0 + k)*stride_KV2 + i];
+                KV_tmp2[k*(D/2) + i].x = __low2float(tmp);
+                KV_tmp2[k*(D/2) + i].y = __high2float(tmp);
            }
        }
@@ -312,7 +292,6 @@ static __global__ void flash_attn_tile_ext_f32(
    GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
    GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
    GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
-    GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);
    NO_DEVICE_CODE;
 #endif // FLASH_ATTN_AVAILABLE
 }


@@ -18,33 +18,13 @@ static __global__ void flash_attn_vec_ext_f16(
        const float m1,
        const uint32_t n_head_log2,
        const float logit_softcap,
-        const int ne00,
-        const int ne01,
-        const int ne02,
-        const int ne03,
-        const int ne10,
-        const int ne11,
-        const int ne12,
-        const int ne13,
-        const int ne31,
-        const int ne32,
-        const int ne33,
-        const int nb31,
-        const int nb32,
-        const int nb33,
-        const int nb01,
-        const int nb02,
-        const int nb03,
-        const int nb11,
-        const int nb12,
-        const int nb13,
-        const int nb21,
-        const int nb22,
-        const int nb23,
-        const int ne0,
-        const int ne1,
-        const int ne2,
-        const int ne3) {
+        const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
+        const int32_t nb01, const int32_t nb02, const int32_t nb03,
+        const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
+        const int32_t nb11, const int32_t nb12, const int64_t nb13,
+        const int32_t nb21, const int32_t nb22, const int64_t nb23,
+        const int32_t ne31, const int32_t ne32, const int32_t ne33,
+        const int32_t nb31, const int32_t nb32, const int64_t nb33) {
 #if defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
 
    // Skip unused kernel variants for faster compilation:
@@ -191,13 +171,16 @@ static __global__ void flash_attn_vec_ext_f16(
    half2 VKQ[ncols] = {{0.0f, 0.0f}};
 
+    K += blockIdx.y*D * nb11;
+    V += blockIdx.y*D * nb21;
+    maskh += blockIdx.y*D;
+
    for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
        // Calculate KQ tile and keep track of new maximum KQ values:
 
        if (mask) {
 #pragma unroll
            for (int j = 0; j < ncols; ++j) {
-                maskh_shared[j*D + tid] = slopeh*maskh[j*ne11 + k_VKQ_0 + tid];
+                maskh_shared[j*D + tid] = slopeh*maskh[j*ne11 + tid];
            }
 
            __syncthreads();
@@ -244,7 +227,7 @@ static __global__ void flash_attn_vec_ext_f16(
 #pragma unroll
            for (int j = 0; j < ncols; ++j) {
-                half sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_h2[j], Q_i32[j], Q_ds[j]);
+                half sum = vec_dot_KQ(K + i_KQ*nb11, Q_h2[j], Q_i32[j], Q_ds[j]);
                sum = warp_reduce_sum((float)sum);
 
                if (use_logit_softcap) {
@@ -300,14 +283,18 @@ static __global__ void flash_attn_vec_ext_f16(
            }
 
            half2 V_k;
-            reinterpret_cast<half&>(V_k.x) = dequantize_1_v(V + (k_VKQ_0 + k0 + 0)*nb21, tid);
-            reinterpret_cast<half&>(V_k.y) = dequantize_1_v(V + (k_VKQ_0 + k0 + 1)*nb21, tid);
+            reinterpret_cast<half&>(V_k.x) = dequantize_1_v(V + (k0 + 0)*nb21, tid);
+            reinterpret_cast<half&>(V_k.y) = dequantize_1_v(V + (k0 + 1)*nb21, tid);
 #pragma unroll
            for (int j = 0; j < ncols; ++j) {
                VKQ[j] += V_k*KQ2[j*(D/2) + k0/2];
            }
        }
 
+        K += gridDim.y*D * nb11;
+        V += gridDim.y*D * nb21;
+        maskh += gridDim.y*D;
+
        __syncthreads();
    }
@@ -351,8 +338,7 @@ static __global__ void flash_attn_vec_ext_f16(
    GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
    GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
    GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
-    GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
-    GGML_UNUSED(ne2); GGML_UNUSED(ne3);
+    GGML_UNUSED(nb23);
    NO_DEVICE_CODE;
 #endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
 }


@@ -18,33 +18,13 @@ static __global__ void flash_attn_vec_ext_f32(
        const float m1,
        const uint32_t n_head_log2,
        const float logit_softcap,
-        const int ne00,
-        const int ne01,
-        const int ne02,
-        const int ne03,
-        const int ne10,
-        const int ne11,
-        const int ne12,
-        const int ne13,
-        const int ne31,
-        const int ne32,
-        const int ne33,
-        const int nb31,
-        const int nb32,
-        const int nb33,
-        const int nb01,
-        const int nb02,
-        const int nb03,
-        const int nb11,
-        const int nb12,
-        const int nb13,
-        const int nb21,
-        const int nb22,
-        const int nb23,
-        const int ne0,
-        const int ne1,
-        const int ne2,
-        const int ne3) {
+        const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
+        const int32_t nb01, const int32_t nb02, const int32_t nb03,
+        const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
+        const int32_t nb11, const int32_t nb12, const int64_t nb13,
+        const int32_t nb21, const int32_t nb22, const int64_t nb23,
+        const int32_t ne31, const int32_t ne32, const int32_t ne33,
+        const int32_t nb31, const int32_t nb32, const int64_t nb33) {
 #ifdef FLASH_ATTN_AVAILABLE
 
    // Skip unused kernel variants for faster compilation:
@@ -59,8 +39,7 @@ static __global__ void flash_attn_vec_ext_f32(
        GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
        GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
        GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
-        GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
-        GGML_UNUSED(ne2); GGML_UNUSED(ne3);
+        GGML_UNUSED(nb23);
        NO_DEVICE_CODE;
        return;
    }
@@ -198,13 +177,16 @@ static __global__ void flash_attn_vec_ext_f32(
    float VKQ[ncols] = {0.0f};
 
+    K += blockIdx.y*D * nb11;
+    V += blockIdx.y*D * nb21;
+    maskh += blockIdx.y*D;
+
    for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
        // Calculate KQ tile and keep track of new maximum KQ values:
 
        if (mask) {
 #pragma unroll
            for (int j = 0; j < ncols; ++j) {
-                maskf_shared[j*D + tid] = slope*__half2float(maskh[j*ne11 + k_VKQ_0 + tid]);
+                maskf_shared[j*D + tid] = slope*__half2float(maskh[j*ne11 + tid]);
            }
 
            __syncthreads();
@@ -246,7 +228,7 @@ static __global__ void flash_attn_vec_ext_f32(
 #pragma unroll
            for (int j = 0; j < ncols; ++j) {
-                float sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_f2[j], Q_i32[j], Q_ds[j]);
+                float sum = vec_dot_KQ(K + i_KQ*nb11, Q_f2[j], Q_i32[j], Q_ds[j]);
                sum = warp_reduce_sum(sum);
 
                if (use_logit_softcap) {
@@ -297,13 +279,17 @@ static __global__ void flash_attn_vec_ext_f32(
                break;
            }
 
-            const float V_ki = dequantize_1_v(V + (k_VKQ_0 + k)*nb21, tid);
+            const float V_ki = dequantize_1_v(V + k*nb21, tid);
 #pragma unroll
            for (int j = 0; j < ncols; ++j) {
                VKQ[j] += V_ki*KQ[j*D + k];
            }
        }
 
+        K += gridDim.y*D * nb11;
+        V += gridDim.y*D * nb21;
+        maskh += gridDim.y*D;
+
        __syncthreads();
    }
@@ -348,7 +334,6 @@ static __global__ void flash_attn_vec_ext_f32(
    GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
    GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
    GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
-    GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);
    NO_DEVICE_CODE;
 #endif // FLASH_ATTN_AVAILABLE
 }


@@ -37,33 +37,13 @@ static __global__ void flash_attn_ext_f16(
        const float m1,
        const uint32_t n_head_log2,
        const float logit_softcap,
-        const int ne00,
-        const int ne01,
-        const int ne02,
-        const int ne03,
-        const int ne10,
-        const int ne11,
-        const int ne12,
-        const int ne13,
-        const int ne31,
-        const int ne32,
-        const int ne33,
-        const int nb31,
-        const int nb32,
-        const int nb33,
-        const int nb01,
-        const int nb02,
-        const int nb03,
-        const int nb11,
-        const int nb12,
-        const int nb13,
-        const int nb21,
-        const int nb22,
-        const int nb23,
-        const int ne0,
-        const int ne1,
-        const int ne2,
-        const int ne3) {
+        const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
+        const int32_t nb01, const int32_t nb02, const int32_t nb03,
+        const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
+        const int32_t nb11, const int32_t nb12, const int64_t nb13,
+        const int32_t nb21, const int32_t nb22, const int64_t nb23,
+        const int32_t ne31, const int32_t ne32, const int32_t ne33,
+        const int32_t nb31, const int32_t nb32, const int64_t nb33) {
 #if !defined(GGML_HIP_NO_ROCWMMA_FATTN) && defined(FLASH_ATTN_AVAILABLE) && ((__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || __CUDA_ARCH__ == GGML_CUDA_CC_TURING) || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)))
    // Skip unused kernel variants for faster compilation:
    if (use_logit_softcap && !(D == 128 || D == 256)) {
@@ -197,7 +177,7 @@ static __global__ void flash_attn_ext_f16(
 #pragma unroll
        for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
            frag_a_K K_a;
-            wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
+            wmma::load_matrix_sync(K_a, K_h + int64_t(k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
 #pragma unroll
            for (int j = 0; j < ncols/frag_n; ++j) {
                wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
@@ -344,7 +324,7 @@ static __global__ void flash_attn_ext_f16(
                const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
 
                frag_a_V v_a;
-                wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
+                wmma::load_matrix_sync(v_a, V_h + int64_t(k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
 #pragma unroll
                for (int j = 0; j < ncols/frag_n; ++j) {
                    wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
@@ -451,7 +431,6 @@ static __global__ void flash_attn_ext_f16(
    GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
    GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
    GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
-    GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);
    NO_DEVICE_CODE;
 #endif // defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)))
 }


@@ -280,23 +280,13 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
    const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
    const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
 
-    if (GGML_CUDA_CC_IS_AMD(cc)) {
 #if defined(GGML_HIP_ROCWMMA_FATTN)
-        if (fp16_mma_available(cc)) {
+    if (GGML_CUDA_CC_IS_AMD(cc) && fp16_mma_available(cc)) {
        ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
        return;
    }
 #endif // defined(GGML_HIP_ROCWMMA_FATTN)
 
-        // On AMD the tile kernels perform poorly, use the vec kernel instead:
-        if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
-            ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
-        } else {
-            ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
-        }
-        return;
-    }
-
    if (!fast_fp16_available(cc)) {
        if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
            ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);


@ -57,6 +57,7 @@ bool g_mul_mat_q = true;
#include <cstddef> #include <cstddef>
#include <cstdint> #include <cstdint>
#include <float.h> #include <float.h>
#include <initializer_list>
#include <limits> #include <limits>
#include <map> #include <map>
#include <memory> #include <memory>
@ -2770,6 +2771,39 @@ static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
} }
#endif #endif
static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops) {
if (!ggml_can_fuse(cgraph, node_idx, ops)) {
return false;
}
if (ops.size() == 2 && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
const ggml_tensor *rms_norm = cgraph->nodes[node_idx];
const ggml_tensor *mul = cgraph->nodes[node_idx+1];
GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT(rms_norm->type == GGML_TYPE_F32);
//rms norm only supports F32
if (mul->src[0]->type != GGML_TYPE_F32 ||
mul->src[1]->type != GGML_TYPE_F32 ||
mul->type != GGML_TYPE_F32) {
return false;
}
//if rms norm is the B operand, then we don't handle broadcast
if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm->src[1])) {
return false;
}
//rms_norm kernel assumes contigous rows
if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
return false;
}
}
return true;
}
static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
bool & graph_evaluated_or_captured, bool & use_cuda_graph, bool & cuda_graph_update_required) { bool & graph_evaluated_or_captured, bool & use_cuda_graph, bool & cuda_graph_update_required) {
// flag used to determine whether it is an integrated_gpu // flag used to determine whether it is an integrated_gpu
@ -2779,6 +2813,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
// Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph. // Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph.
// With the use of CUDA graphs, the execution will be performed by the graph launch. // With the use of CUDA graphs, the execution will be performed by the graph launch.
if (!use_cuda_graph || cuda_graph_update_required) { if (!use_cuda_graph || cuda_graph_update_required) {
for (int i = 0; i < cgraph->n_nodes; i++) { for (int i = 0; i < cgraph->n_nodes; i++) {
ggml_tensor * node = cgraph->nodes[i]; ggml_tensor * node = cgraph->nodes[i];
@@ -2786,6 +2821,12 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
continue; continue;
} }
static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
if (!disable_fusion && ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i+1]);
i++;
continue;
}
#ifndef NDEBUG #ifndef NDEBUG
assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device)); assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
for (int j = 0; j < GGML_MAX_SRC; j++) { for (int j = 0; j < GGML_MAX_SRC; j++) {
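
The fused RMS_NORM + MUL dispatch above can be switched off at runtime for debugging: running with GGML_CUDA_DISABLE_FUSION set in the environment makes the loop fall back to the separate kernels. A minimal standalone sketch of that gating pattern (an illustration, not the ggml sources themselves), caching the getenv() result in a function-local static so the environment is only inspected once:

#include <cstdlib>

// sketch: any setting of GGML_CUDA_DISABLE_FUSION (even an empty value) disables
// the fused path; the result is cached so getenv() runs only on the first call
static bool fusion_disabled() {
    static const bool disabled = std::getenv("GGML_CUDA_DISABLE_FUSION") != nullptr;
    return disabled;
}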

View file

@@ -104,10 +104,12 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
} }
} }
template <int block_size> template <int block_size, bool do_multiply = false>
static __global__ void rms_norm_f32( static __global__ void rms_norm_f32(
const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel, const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
const int64_t stride_sample, const float eps) { const int64_t stride_sample, const float eps, const float * mul = nullptr, const int64_t mul_stride_row = 0,
const int64_t mul_stride_channel = 0, const int64_t mul_stride_sample = 0, const int mul_ncols = 0,
const int mul_nrows = 0, const int mul_nchannels = 0, const int mul_nsamples = 0) {
const int nrows = gridDim.x; const int nrows = gridDim.x;
const int nchannels = gridDim.y; const int nchannels = gridDim.y;
@@ -119,6 +121,13 @@ static __global__ void rms_norm_f32(
x += sample*stride_sample + channel*stride_channel + row*stride_row; x += sample*stride_sample + channel*stride_channel + row*stride_row;
dst += ((sample*nchannels + channel)*nrows + row)*ncols; dst += ((sample*nchannels + channel)*nrows + row)*ncols;
if constexpr (do_multiply) {
const int mul_row = row % mul_nrows;
const int mul_channel = channel % mul_nchannels;
const int mul_sample = sample % mul_nsamples;
mul += mul_sample*mul_stride_sample + mul_channel*mul_stride_channel + mul_row*mul_stride_row;
}
float tmp = 0.0f; // partial sum for thread in warp float tmp = 0.0f; // partial sum for thread in warp
for (int col = tid; col < ncols; col += block_size) { for (int col = tid; col < ncols; col += block_size) {
@@ -145,8 +154,13 @@ static __global__ void rms_norm_f32(
const float scale = rsqrtf(mean + eps); const float scale = rsqrtf(mean + eps);
for (int col = tid; col < ncols; col += block_size) { for (int col = tid; col < ncols; col += block_size) {
if constexpr (do_multiply) {
const int mul_col = col % mul_ncols;
dst[col] = scale * x[col] * mul[mul_col];
} else {
dst[col] = scale * x[col]; dst[col] = scale * x[col];
} }
}
} }
template <int block_size> template <int block_size>
@@ -310,10 +324,30 @@ static void rms_norm_f32_cuda(
const dim3 blocks_num(nrows, nchannels, nsamples); const dim3 blocks_num(nrows, nchannels, nsamples);
if (ncols < 1024) { if (ncols < 1024) {
const dim3 block_dims(WARP_SIZE, 1, 1); const dim3 block_dims(WARP_SIZE, 1, 1);
rms_norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); rms_norm_f32<WARP_SIZE, false><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
} else { } else {
const dim3 block_dims(1024, 1, 1); const dim3 block_dims(1024, 1, 1);
rms_norm_f32<1024><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); rms_norm_f32<1024, false><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
}
}
static void rms_norm_mul_f32_cuda(
const float * x, const float * mul, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample,
const int64_t mul_stride_row, const int64_t mul_stride_channel, const int64_t mul_stride_sample,
const int mul_ncols, const int mul_nrows, const int mul_nchannels, const int mul_nsamples,
const float eps, cudaStream_t stream) {
const dim3 blocks_num(nrows, nchannels, nsamples);
if (mul == nullptr) {
rms_norm_f32_cuda(x, dst, ncols, nrows, nchannels, nsamples, stride_row, stride_channel, stride_sample, eps, stream);
return;
}
if (ncols < 1024) {
const dim3 block_dims(WARP_SIZE, 1, 1);
rms_norm_f32<WARP_SIZE, true><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, mul_stride_sample, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples);
} else {
const dim3 block_dims(1024, 1, 1);
rms_norm_f32<1024, true><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, mul_stride_sample, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples);
} }
} }
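
For reference, the fused launcher computes per row exactly what running RMS norm followed by an element-wise multiply would produce, with the multiplier broadcast by taking indices modulo its extents. A scalar sketch of the per-row math under the assumptions the fusion check enforces (F32 data, contiguous rows); this is an illustration, not the CUDA kernel:

#include <cmath>

// one row of the fused rms_norm + mul: dst = x * rsqrt(mean(x^2) + eps) * mul,
// where mul is broadcast along the row via col % mul_ncols
static void rms_norm_mul_row_ref(const float * x, const float * mul, float * dst,
                                 int ncols, int mul_ncols, float eps) {
    float sum = 0.0f;
    for (int col = 0; col < ncols; ++col) {
        sum += x[col] * x[col];
    }
    const float scale = 1.0f / std::sqrt(sum / ncols + eps);
    for (int col = 0; col < ncols; ++col) {
        dst[col] = scale * x[col] * mul[col % mul_ncols];
    }
}
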
@@ -407,6 +441,59 @@ void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
rms_norm_f32_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream); rms_norm_f32_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream);
} }
void ggml_cuda_op_rms_norm_fused(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * mul_tensor) {
const ggml_tensor * rms_norm_src = (ggml_tensor *) dst->src[0];
float eps = 0.0f;
memcpy(&eps, dst->op_params, sizeof(float));
const float * src0_d = (const float *) rms_norm_src->data;
const float * mul_d = nullptr;
const ggml_tensor * mul_src = nullptr;
if (mul_tensor->src[0] == dst) {
mul_d = (float *) mul_tensor->src[1]->data;
mul_src = mul_tensor->src[1];
} else if(mul_tensor->src[1] == dst) {
mul_d = (float *) mul_tensor->src[0]->data;
mul_src = mul_tensor->src[0];
} else {
GGML_ASSERT(false);
}
float * dst_d = (float *) mul_tensor->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(rms_norm_src->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(mul_tensor->type == GGML_TYPE_F32);
GGML_ASSERT(eps >= 0.0f);
const int64_t ne00 = rms_norm_src->ne[0];
const int64_t ne01 = rms_norm_src->ne[1];
const int64_t ne02 = rms_norm_src->ne[2];
const int64_t ne03 = rms_norm_src->ne[3];
const size_t ts0 = ggml_type_size(rms_norm_src->type);
GGML_ASSERT(rms_norm_src->nb[0] == ts0);
const int64_t s01 = rms_norm_src->nb[1] / ts0;
const int64_t s02 = rms_norm_src->nb[2] / ts0;
const int64_t s03 = rms_norm_src->nb[3] / ts0;
const size_t ts_mul = ggml_type_size(mul_src->type);
GGML_ASSERT(mul_src->nb[0] == ts_mul);
const int64_t mul_s01 = mul_src->nb[1] / ts_mul;
const int64_t mul_s02 = mul_src->nb[2] / ts_mul;
const int64_t mul_s03 = mul_src->nb[3] / ts_mul;
const int mul_ncols = mul_src->ne[0];
const int mul_nrows = mul_src->ne[1];
const int mul_nchannels = mul_src->ne[2];
const int mul_nsamples = mul_src->ne[3];
rms_norm_mul_f32_cuda(src0_d, mul_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, mul_s01, mul_s02, mul_s03, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples, eps, stream);
}
void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * grad = dst->src[0]; // gradients const ggml_tensor * grad = dst->src[0]; // gradients
const ggml_tensor * src0f = dst->src[1]; // src0 from forward pass const ggml_tensor * src0f = dst->src[1]; // src0 from forward pass

View file

@@ -6,6 +6,8 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_rms_norm_fused(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * mul_tensor);
void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_l2_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_l2_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View file

@@ -13,7 +13,7 @@
#define CUBLAS_OP_N MUBLAS_OP_N #define CUBLAS_OP_N MUBLAS_OP_N
#define CUBLAS_OP_T MUBLAS_OP_T #define CUBLAS_OP_T MUBLAS_OP_T
#define CUBLAS_STATUS_SUCCESS MUBLAS_STATUS_SUCCESS #define CUBLAS_STATUS_SUCCESS MUBLAS_STATUS_SUCCESS
#define CUBLAS_TF32_TENSOR_OP_MATH MUBLAS_MATH_MODE_DEFAULT #define CUBLAS_TF32_TENSOR_OP_MATH MUBLAS_TENSOR_OP_MATH
#define CUDA_R_16F MUSA_R_16F #define CUDA_R_16F MUSA_R_16F
#define CUDA_R_16BF MUSA_R_16BF #define CUDA_R_16BF MUSA_R_16BF
#define CUDA_R_32F MUSA_R_32F #define CUDA_R_32F MUSA_R_32F
@@ -29,7 +29,7 @@
#define cublasSgemm mublasSgemm #define cublasSgemm mublasSgemm
#define cublasStatus_t mublasStatus_t #define cublasStatus_t mublasStatus_t
#define cublasOperation_t mublasOperation_t #define cublasOperation_t mublasOperation_t
#define cublasGetStatusString mublasStatus_to_string #define cublasGetStatusString mublasGetStatusString
#define cudaDataType_t musaDataType_t #define cudaDataType_t musaDataType_t
#define cudaDeviceCanAccessPeer musaDeviceCanAccessPeer #define cudaDeviceCanAccessPeer musaDeviceCanAccessPeer
#define cudaDeviceDisablePeerAccess musaDeviceDisablePeerAccess #define cudaDeviceDisablePeerAccess musaDeviceDisablePeerAccess

View file

@@ -1955,6 +1955,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
static int ggml_metal_encode_node( static int ggml_metal_encode_node(
ggml_backend_t backend, ggml_backend_t backend,
int idx, int idx,
int idx_end,
id<MTLComputeCommandEncoder> encoder, id<MTLComputeCommandEncoder> encoder,
struct ggml_metal_mem_pool * mem_pool) { struct ggml_metal_mem_pool * mem_pool) {
struct ggml_backend_metal_context * ctx = backend->context; struct ggml_backend_metal_context * ctx = backend->context;
@@ -2181,7 +2182,9 @@ static int ggml_metal_encode_node(
size_t offs_fuse; size_t offs_fuse;
id<MTLBuffer> id_fuse; id<MTLBuffer> id_fuse;
for (n_fuse = 0; n_fuse <= 6; ++n_fuse) { // note: in metal, we sometimes encode the graph in parallel so we have to avoid fusing nodes
// across splits. idx_end indicates the last node in the current split
for (n_fuse = 0; n_fuse <= 6 && idx + n_fuse + 1 < idx_end; ++n_fuse) {
if (!ggml_can_fuse(gf, idx + n_fuse, ops + n_fuse, 2)) { if (!ggml_can_fuse(gf, idx + n_fuse, ops + n_fuse, 2)) {
break; break;
} }
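
The extra idx + n_fuse + 1 < idx_end condition keeps a fused group inside the encoder's own slice of the graph: an encoder owning nodes [idx, idx_end) may only look at nodes idx+1 .. idx_end-1 when trying to extend a fusion starting at idx. A simplified sketch of that bound (an illustration, not the Metal backend itself):

#include <algorithm>

// sketch: upper bound on how many extra nodes a fusion starting at idx may pull in,
// given that the encoder's split ends at idx_end and the backend's own limit is `wanted`
static int max_extra_fused_nodes(int idx, int idx_end, int wanted) {
    return std::max(0, std::min(wanted, idx_end - idx - 1));
}
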
@@ -4288,7 +4291,7 @@ static int ggml_metal_encode_node(
ops[1] = GGML_OP_MUL; ops[1] = GGML_OP_MUL;
ops[2] = GGML_OP_ADD; ops[2] = GGML_OP_ADD;
for (n_fuse = 0; n_fuse <= 1; ++n_fuse) { for (n_fuse = 0; n_fuse <= 1 && idx + n_fuse + 1 < idx_end; ++n_fuse) {
if (!ggml_can_fuse(gf, idx + n_fuse, ops + n_fuse, 2)) { if (!ggml_can_fuse(gf, idx + n_fuse, ops + n_fuse, 2)) {
break; break;
} }
@@ -6271,7 +6274,11 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
[encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]]; [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
} }
const int res = ggml_metal_encode_node(backend, idx, encoder, mem_pool); const int res = ggml_metal_encode_node(backend, idx, node_end, encoder, mem_pool);
if (idx + res > node_end) {
GGML_ABORT("fusion error: nodes spanning multiple encoders have been fused. this indicates a bug in the fusion logic %s",
"https://github.com/ggml-org/llama.cpp/pull/14849");
}
if (should_capture) { if (should_capture) {
[encoder popDebugGroup]; [encoder popDebugGroup];

View file

@@ -48,11 +48,11 @@ template <> struct block_q_t<GGML_TYPE_Q4_0> {
}; };
static constexpr std::pair<int, int> get_block_offset(const int block_index, const int /* nblocks */) { static constexpr std::pair<int, int> get_block_offset(const int block_index, const int /* nblocks */) {
return { block_index * (traits::qk / traits::qr), 0 }; return { block_index * (QK4_0 / QR4_0), 0 };
} }
static constexpr std::pair<int, int> get_d_offset(int nrows, int ncols, const int block_index) { static constexpr std::pair<int, int> get_d_offset(int nrows, int ncols, const int block_index) {
return { (ncols / traits::qr * nrows) + block_index * sizeof(ggml_half), 0 }; return { (ncols / QR4_0 * nrows) + block_index * sizeof(ggml_half), 0 };
} }
static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; } static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; }
@@ -71,14 +71,12 @@ template <> struct block_q_t<GGML_TYPE_Q4_K> {
} }
static constexpr std::pair<int, int> get_d_offset(int nrows, int ncols, const int block_index) { static constexpr std::pair<int, int> get_d_offset(int nrows, int ncols, const int block_index) {
auto nblocks = (nrows * (ncols / traits::qk)); auto nblocks = (nrows * (ncols / QK_K));
return { nblocks * (QK_K / 2), return { nblocks * (QK_K / 2) + (block_index * K_SCALE_SIZE),
(nblocks * QK_K / 2) + (nblocks * K_SCALE_SIZE) + (block_index * sizeof(ggml_half2)) }; (nblocks * QK_K / 2) + (nblocks * K_SCALE_SIZE) + (block_index * sizeof(ggml_half2)) };
} }
static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; } static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; }
constexpr size_t get_total_qs_bytes(int nblocks) { return nblocks * QK_K / 2; }
}; };
template <> struct block_q_t<GGML_TYPE_Q6_K> { template <> struct block_q_t<GGML_TYPE_Q6_K> {
@@ -90,22 +88,23 @@ template <> struct block_q_t<GGML_TYPE_Q6_K> {
}; };
static constexpr std::pair<int, int> get_block_offset(const int block_index, const int n_blocks) { static constexpr std::pair<int, int> get_block_offset(const int block_index, const int n_blocks) {
auto low_bits_index = block_index * (traits::qk / traits::qr); auto low_bits_index = block_index * (QK_K / QR6_K);
// the index of the high bits is after all the low bits // the index of the high bits is after all the low bits
auto high_bits_index = n_blocks * (QK_K / 2) + (block_index * (QK_K / 4)); auto high_bits_index = n_blocks * (QK_K / 2) + (block_index * (QK_K / 4));
return { low_bits_index, high_bits_index }; return { low_bits_index, high_bits_index };
} }
static constexpr std::pair<int, int> get_d_offset(int nrows, int ncols, const int block_index) { static constexpr std::pair<int, int> get_d_offset(int nrows, int ncols, const int block_index) {
auto nblocks = (nrows * (ncols / traits::qk)); auto nblocks = (nrows * (ncols / QK_K));
auto total_qs_bytes = nblocks * (QK_K / 2) + nblocks * (QK_K / 4); auto total_qs_bytes = nblocks * (QK_K / 2) + nblocks * (QK_K / 4);
auto block_scales = total_qs_bytes + block_index * (QK_K / 16); auto block_scales = total_qs_bytes + block_index * (QK_K / 16);
auto sb_scale = total_qs_bytes + nblocks * (QK_K / 16); auto sb_scale = total_qs_bytes + nblocks * (QK_K / 16) + block_index * sizeof(ggml_half);
return { block_scales, sb_scale }; return { block_scales, sb_scale };
} }
static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; } static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; }
}; };
} // namespace ggml_sycl_reordered } // namespace ggml_sycl_reordered
#endif // GGML_SYCL_QUANTS_HPP #endif // GGML_SYCL_QUANTS_HPP
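
To make the reordered Q6_K offsets above concrete, a small standalone arithmetic check (assuming QK_K = 256, QR6_K = 2 and a 2-byte ggml_half, so each super-block contributes 128 bytes of low bits, 64 bytes of high bits, 16 scale bytes and one half-precision d; the 4-block tensor below is purely hypothetical):

#include <cstdio>

// worked example of the reordered Q6_K layout: all low bits first, then all high
// bits, then all scales, then the per-block d values
int main() {
    const long QK_K    = 256;
    const long nblocks = 4; // hypothetical tensor with 4 super-blocks
    const long i       = 1; // offsets for block index 1

    const long low_bits  = i * (QK_K / 2);                                       // 128
    const long high_bits = nblocks * (QK_K / 2) + i * (QK_K / 4);                // 576
    const long scales    = nblocks * (QK_K / 2 + QK_K / 4) + i * (QK_K / 16);    // 784
    const long d         = nblocks * (QK_K / 2 + QK_K / 4 + QK_K / 16) + i * 2;  // 834

    std::printf("low=%ld high=%ld scales=%ld d=%ld\n", low_bits, high_bits, scales, d);
    return 0;
}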

View file

@@ -10272,7 +10272,7 @@ static bool ggml_vk_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, st
} }
// if rms_norm is the B operand, then we don't handle broadcast // if rms_norm is the B operand, then we don't handle broadcast
if (rms_norm == mul->src[1] && if (rms_norm == mul->src[1] &&
mul->src[0]->ne[1] != rms_norm->ne[1]) { !ggml_are_same_shape(mul->src[0], rms_norm)) {
return false; return false;
} }
// rms_norm shader assumes contiguous rows // rms_norm shader assumes contiguous rows

View file

@@ -50,9 +50,15 @@ void main() {
const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1)); const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1));
if (do_multiply) { if (do_multiply) {
if (ncols > p.ne10) {
[[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + fastmod(col, p.ne10)]));
}
} else {
[[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) { [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + col])); data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + col]));
} }
}
} else { } else {
[[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) { [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col])); data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]));

View file

@@ -144,6 +144,10 @@ class Metadata:
# Quick hack to fix the Norway problem # Quick hack to fix the Norway problem
# https://hitchdev.com/strictyaml/why/implicit-typing-removed/ # https://hitchdev.com/strictyaml/why/implicit-typing-removed/
yaml_content = yaml_content.replace("- no\n", "- \"no\"\n") yaml_content = yaml_content.replace("- no\n", "- \"no\"\n")
# yaml should use 2 spaces instead of tabs
# this issue has come up with the Qwen/Qwen3-235B-A22B-Instruct-2507 model card
# (I've also sent a PR to fix the model card)
yaml_content = yaml_content.replace("\t", " ")
if yaml_content: if yaml_content:
data = yaml.safe_load(yaml_content) data = yaml.safe_load(yaml_content)

View file

@@ -959,6 +959,7 @@ extern "C" {
// in the order they have appeared in the batch. // in the order they have appeared in the batch.
// Rows: number of tokens for which llama_batch.logits[i] != 0 // Rows: number of tokens for which llama_batch.logits[i] != 0
// Cols: n_vocab // Cols: n_vocab
// TODO: deprecate in favor of llama_get_logits_ith() (ref: https://github.com/ggml-org/llama.cpp/pull/14853#issuecomment-3113143522)
LLAMA_API float * llama_get_logits(struct llama_context * ctx); LLAMA_API float * llama_get_logits(struct llama_context * ctx);
// Logits for the ith token. For positive indices, Equivalent to: // Logits for the ith token. For positive indices, Equivalent to:
@@ -973,6 +974,7 @@ extern "C" {
// in the order they have appeared in the batch. // in the order they have appeared in the batch.
// shape: [n_outputs*n_embd] // shape: [n_outputs*n_embd]
// Otherwise, returns NULL. // Otherwise, returns NULL.
// TODO: deprecate in favor of llama_get_embeddings_ith() (ref: https://github.com/ggml-org/llama.cpp/pull/14853#issuecomment-3113143522)
LLAMA_API float * llama_get_embeddings(struct llama_context * ctx); LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
// Get the embeddings for the ith token. For positive indices, Equivalent to: // Get the embeddings for the ith token. For positive indices, Equivalent to:

View file

@@ -1933,12 +1933,6 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
} }
}, },
{
LLM_ARCH_UNKNOWN,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
},
},
{ {
LLM_ARCH_DREAM, LLM_ARCH_DREAM,
{ {
@@ -1956,6 +1950,12 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
}, },
}, },
{
LLM_ARCH_UNKNOWN,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
},
},
}; };
static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = { static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {

View file

@@ -718,11 +718,10 @@ int32_t llm_chat_apply_template(
} }
ss << message->content << "<|im_end|>"; ss << message->content << "<|im_end|>";
}
if (add_ass) { if (add_ass) {
ss << "<|im_assistant|>assistant<|im_middle|>"; ss << "<|im_assistant|>assistant<|im_middle|>";
} }
}
} else { } else {
// template not supported // template not supported
return -1; return -1;

View file

@@ -508,12 +508,16 @@ enum llama_pooling_type llama_context::pooling_type() const {
} }
float * llama_context::get_logits() { float * llama_context::get_logits() {
output_reorder();
return logits; return logits;
} }
float * llama_context::get_logits_ith(int32_t i) { float * llama_context::get_logits_ith(int32_t i) {
int64_t j = -1; int64_t j = -1;
output_reorder();
try { try {
if (logits == nullptr) { if (logits == nullptr) {
throw std::runtime_error("no logits"); throw std::runtime_error("no logits");
@@ -550,12 +554,16 @@ float * llama_context::get_logits_ith(int32_t i) {
} }
float * llama_context::get_embeddings() { float * llama_context::get_embeddings() {
output_reorder();
return embd; return embd;
} }
float * llama_context::get_embeddings_ith(int32_t i) { float * llama_context::get_embeddings_ith(int32_t i) {
int64_t j = -1; int64_t j = -1;
output_reorder();
try { try {
if (embd == nullptr) { if (embd == nullptr) {
throw std::runtime_error("no embeddings"); throw std::runtime_error("no embeddings");
@@ -970,6 +978,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
// TODO: this clear of the buffer can easily be forgotten - need something better // TODO: this clear of the buffer can easily be forgotten - need something better
embd_seq.clear(); embd_seq.clear();
output_swaps.clear();
bool did_optimize = false; bool did_optimize = false;
@@ -1189,9 +1198,6 @@ int llama_context::decode(const llama_batch & batch_inp) {
// make the outputs have the same order they had in the user-provided batch // make the outputs have the same order they had in the user-provided batch
// note: this is mostly relevant for recurrent models atm // note: this is mostly relevant for recurrent models atm
if (!sorted_output) { if (!sorted_output) {
const uint32_t n_vocab = model.vocab.n_tokens();
const uint64_t n_embd = model.hparams.n_embd;
GGML_ASSERT((size_t) n_outputs == out_ids.size()); GGML_ASSERT((size_t) n_outputs == out_ids.size());
// TODO: is there something more efficient which also minimizes swaps? // TODO: is there something more efficient which also minimizes swaps?
@@ -1207,16 +1213,9 @@ int llama_context::decode(const llama_batch & batch_inp) {
continue; continue;
} }
std::swap(out_ids[i], out_ids[j_min]); std::swap(out_ids[i], out_ids[j_min]);
if (logits_size > 0) {
for (uint32_t k = 0; k < n_vocab; k++) { // remember the swaps and apply them lazily upon logits/embeddings access
std::swap(logits[i*n_vocab + k], logits[j_min*n_vocab + k]); output_swaps.push_back({ i, j_min });
}
}
if (embd_size > 0) {
for (uint32_t k = 0; k < n_embd; k++) {
std::swap(embd[i*n_embd + k], embd[j_min*n_embd + k]);
}
}
} }
std::fill(output_ids.begin(), output_ids.end(), -1); std::fill(output_ids.begin(), output_ids.end(), -1);
@@ -1307,6 +1306,30 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
return n_outputs_max; return n_outputs_max;
} }
void llama_context::output_reorder() {
const uint32_t n_vocab = model.vocab.n_tokens();
const uint64_t n_embd = model.hparams.n_embd;
for (uint32_t s = 0; s < output_swaps.size(); ++s) {
const uint32_t i0 = output_swaps[s].i0;
const uint32_t i1 = output_swaps[s].i1;
if (logits_size > 0) {
for (uint32_t k = 0; k < n_vocab; k++) {
std::swap(logits[i0*n_vocab + k], logits[i1*n_vocab + k]);
}
}
if (embd_size > 0) {
for (uint32_t k = 0; k < n_embd; k++) {
std::swap(embd[i0*n_embd + k], embd[i1*n_embd + k]);
}
}
}
output_swaps.clear();
}
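
decode() now only records which output rows were swapped; output_reorder() replays those swaps the first time logits or embeddings are actually read. A self-contained sketch of that record-then-flush idea (simplified stand-ins for the llama_context members, not the actual class):

#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

struct lazy_reorder {
    struct swap_info { uint32_t i0, i1; };
    std::vector<swap_info> pending; // filled while decoding, in order

    void note_swap(uint32_t i0, uint32_t i1) { pending.push_back({ i0, i1 }); }

    // replay the recorded swaps on a row-major buffer with row_size floats per row;
    // applying them in recording order reproduces the user's batch order
    void flush(std::vector<float> & buf, std::size_t row_size) {
        for (const auto & s : pending) {
            for (std::size_t k = 0; k < row_size; ++k) {
                std::swap(buf[s.i0 * row_size + k], buf[s.i1 * row_size + k]);
            }
        }
        pending.clear();
    }
};
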
// //
// graph // graph
// //

View file

@@ -181,6 +181,8 @@ private:
// Returns max number of outputs for which space was reserved. // Returns max number of outputs for which space was reserved.
uint32_t output_reserve(int32_t n_outputs); uint32_t output_reserve(int32_t n_outputs);
void output_reorder();
// //
// graph // graph
// //
@@ -250,6 +252,13 @@ private:
std::vector<int32_t> output_ids; // map batch token positions to ids of the logits and embd buffers std::vector<int32_t> output_ids; // map batch token positions to ids of the logits and embd buffers
struct swap_info {
uint32_t i0;
uint32_t i1;
};
std::vector<swap_info> output_swaps;
ggml_backend_sched_ptr sched; ggml_backend_sched_ptr sched;
ggml_backend_t backend_cpu = nullptr; ggml_backend_t backend_cpu = nullptr;

View file

@@ -768,6 +768,8 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std::
// Iterate and write all the keys first, each row is a cell // Iterate and write all the keys first, each row is a cell
// Get whole range at a time // Get whole range at a time
for (uint32_t il = 0; il < n_layer; ++il) { for (uint32_t il = 0; il < n_layer; ++il) {
// skip null layers (read_data will handle this by checking "r_l" and "s_l" for null)
if (r_l[il] == nullptr) continue;
// Write key type // Write key type
const int32_t r_type_i = (int32_t)r_l[il]->type; const int32_t r_type_i = (int32_t)r_l[il]->type;
@@ -787,6 +789,8 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std::
if (!s_trans) { if (!s_trans) {
for (uint32_t il = 0; il < n_layer; ++il) { for (uint32_t il = 0; il < n_layer; ++il) {
// skip null layers (read_data will handle this by checking "r_l" and "s_l" for null)
if (s_l[il] == nullptr) continue;
// Write value type // Write value type
const int32_t s_type_i = (int32_t)s_l[il]->type; const int32_t s_type_i = (int32_t)s_l[il]->type;
@@ -807,6 +811,9 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std::
// When v is transposed, we also need the element size and get the element ranges from each row // When v is transposed, we also need the element size and get the element ranges from each row
const uint32_t mem_size = size; const uint32_t mem_size = size;
for (uint32_t il = 0; il < n_layer; ++il) { for (uint32_t il = 0; il < n_layer; ++il) {
// skip null layers (read_data will handle this by checking "r_l" and "s_l" for null)
if (s_l[il] == nullptr) continue;
const uint32_t n_embd_s = hparams.n_embd_s(); const uint32_t n_embd_s = hparams.n_embd_s();
// Write value type // Write value type
@@ -951,6 +958,8 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
// For each layer, read the keys for each cell, one row is one cell, read as one contiguous block // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
for (uint32_t il = 0; il < n_layer; ++il) { for (uint32_t il = 0; il < n_layer; ++il) {
// skip null layers
if (r_l[il] == nullptr) continue;
// Read type of key // Read type of key
int32_t r_type_i_ref; int32_t r_type_i_ref;
@@ -978,11 +987,14 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
if (!s_trans) { if (!s_trans) {
for (uint32_t il = 0; il < n_layer; ++il) { for (uint32_t il = 0; il < n_layer; ++il) {
// skip null layers
if (s_l[il] == nullptr) continue;
// Read type of value // Read type of value
int32_t s_type_i_ref; int32_t s_type_i_ref;
io.read_to(&s_type_i_ref, sizeof(s_type_i_ref)); io.read_to(&s_type_i_ref, sizeof(s_type_i_ref));
const int32_t s_type_i = (int32_t)s_l[il]->type; const int32_t s_type_i = (int32_t)s_l[il]->type;
if (s_type_i != s_type_i_ref) { if (s_type_i != s_type_i_ref) {
LLAMA_LOG_ERROR("%s: mismatched s type (%d != %d, layer %d)\n", __func__, s_type_i, s_type_i_ref, il); LLAMA_LOG_ERROR("%s: mismatched s type (%d != %d, layer %d)\n", __func__, s_type_i, s_type_i_ref, il);
return false; return false;
@@ -1005,6 +1017,9 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
} else { } else {
// For each layer, read the values for each cell (transposed) // For each layer, read the values for each cell (transposed)
for (uint32_t il = 0; il < n_layer; ++il) { for (uint32_t il = 0; il < n_layer; ++il) {
// skip null layers
if (s_l[il] == nullptr) continue;
const uint32_t n_embd_s = hparams.n_embd_s(); const uint32_t n_embd_s = hparams.n_embd_s();
// Read type of value // Read type of value

View file

@@ -651,6 +651,9 @@ void llama_model::load_hparams(llama_model_loader & ml) {
ml.get_key(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale); ml.get_key(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale);
ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale);
// MiniCPM uses rope by default, unlike Granite which uses it as a switch
hparams.rope_finetuned = true;
switch (hparams.n_layer) { switch (hparams.n_layer) {
case 52: type = LLM_TYPE_1B; break; case 52: type = LLM_TYPE_1B; break;
case 40: type = LLM_TYPE_2B; break; case 40: type = LLM_TYPE_2B; break;
@@ -1549,7 +1552,11 @@ void llama_model::load_hparams(llama_model_loader & ml) {
ml.get_key(LLM_KV_TOKEN_SHIFT_COUNT, hparams.token_shift_count, false); ml.get_key(LLM_KV_TOKEN_SHIFT_COUNT, hparams.token_shift_count, false);
switch (hparams.n_layer) { switch (hparams.n_layer) {
case 12: type = LLM_TYPE_190M; break; case 12:
switch (hparams.n_embd) {
case 768: type = LLM_TYPE_190M; break;
default: type = LLM_TYPE_UNKNOWN;
} break;
case 24: case 24:
switch (hparams.n_embd) { switch (hparams.n_embd) {
case 1024: type = LLM_TYPE_450M; break; case 1024: type = LLM_TYPE_450M; break;
@@ -1562,7 +1569,17 @@ void llama_model::load_hparams(llama_model_loader & ml) {
case 3584: type = LLM_TYPE_7B; break; case 3584: type = LLM_TYPE_7B; break;
default: type = LLM_TYPE_UNKNOWN; default: type = LLM_TYPE_UNKNOWN;
} break; } break;
case 32: type = LLM_TYPE_2_9B; break; // RWKV-7-World case 32:
switch (hparams.n_embd) {
case 2560: type = LLM_TYPE_2_9B; break;
case 4096: type = LLM_TYPE_7B; break;
default: type = LLM_TYPE_UNKNOWN;
} break;
case 61:
switch (hparams.n_embd) {
case 4096: type = LLM_TYPE_14B; break;
default: type = LLM_TYPE_UNKNOWN;
} break;
default: type = LLM_TYPE_UNKNOWN; default: type = LLM_TYPE_UNKNOWN;
} }
} break; } break;