Merge branch 'upstream' into concedo_experimental

# Conflicts:
#	ggml/src/ggml-opencl/ggml-opencl.cpp
#	ggml/src/ggml-sycl/ggml-sycl.cpp
#	requirements/requirements-convert_hf_to_gguf_update.txt
#	scripts/compare-llama-bench.py
#	tests/test-backend-ops.cpp
#	tests/test-chat.cpp
#	tools/imatrix/README.md
#	tools/imatrix/imatrix.cpp
#	tools/llama-bench/llama-bench.cpp
Concedo 2025-08-04 22:42:02 +08:00
commit 8bd0a560f0
23 changed files with 512 additions and 210 deletions

View file

@ -2649,6 +2649,15 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.n_out_freq = value;
}
).set_examples({LLAMA_EXAMPLE_IMATRIX}));
add_opt(common_arg(
{"--output-format"}, "{gguf,dat}",
string_format("output format for imatrix file (default: %s)", params.imat_dat ? "dat" : "gguf"),
[](common_params & params, const std::string & value) {
/**/ if (value == "gguf") { params.imat_dat = false; }
else if (value == "dat") { params.imat_dat = true; }
else { throw std::invalid_argument("invalid output format"); }
}
).set_examples({LLAMA_EXAMPLE_IMATRIX}));
add_opt(common_arg(
{"--save-frequency"}, "N",
string_format("save an imatrix copy every N iterations (default: %d)", params.n_save_freq),

View file

@ -1646,7 +1646,7 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) {
"|<function name=\"([^\"]+)\">" // match 5 (function name again)
);
if (auto res = builder.try_find_regex(open_regex)) {
while (auto res = builder.try_find_regex(open_regex)) {
const auto & block_start = res->groups[1];
std::string block_end = block_start.empty() ? "" : "```";
@ -1668,7 +1668,6 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) {
builder.consume_literal(block_end);
builder.consume_spaces();
}
builder.add_content(builder.consume_rest());
} else {
throw common_chat_msg_partial_exception("failed to parse tool call");
}
@ -1693,11 +1692,10 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) {
builder.consume_spaces();
}
}
builder.add_content(builder.consume_rest());
}
} else {
builder.add_content(builder.consume_rest());
}
builder.add_content(builder.consume_rest());
}
static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) {

View file

@ -435,6 +435,7 @@ struct common_params {
int32_t n_out_freq = 10; // output the imatrix every n_out_freq iterations
int32_t n_save_freq = 0; // save the imatrix every n_save_freq iterations
int32_t i_chunk = 0; // start processing from this chunk
bool imat_dat = false; // whether the legacy imatrix.dat format should be output
bool process_output = false; // collect data for the output tensor
bool compute_ppl = true; // whether to compute perplexity

View file

@ -702,6 +702,9 @@ class TextModel(ModelBase):
if chkhsh == "81212dc7cdb7e0c1074ca62c5aeab0d43c9f52b8a737be7b12a777c953027890":
# ref: https://huggingface.co/moonshotai/Kimi-K2-Base
res = "kimi-k2"
if chkhsh == "d4540891389ea895b53b399da6ac824becc30f2fba0e9ddbb98f92e55ca0e97c":
# ref: https://huggingface.co/Qwen/Qwen3-Embedding-0.6B
res = "qwen2"
if chkhsh == "0ef9807a4087ebef797fc749390439009c3b9eda9ad1a097abbe738f486c01e5":
# ref: https://huggingface.co/meta-llama/Meta-Llama-3-8B
res = "llama-bpe"
@ -849,6 +852,9 @@ class TextModel(ModelBase):
if chkhsh == "2085e1638f6c377a0aa4ead21b27bb4cb941bf800df86ed391011769c1758dfb":
# ref: https://huggingface.co/LGAI-EXAONE/EXAONE-4.0-32B
res = "exaone4"
if chkhsh == "a1e163ecab2e718a4c829d1148b6e86824ec36163bb71941c3dca9cd5ac25756":
# ref: https://huggingface.co/JetBrains/Mellum-4b-base
res = "mellum"
if res is None:
logger.warning("\n")
@ -6056,6 +6062,7 @@ class DeepseekModel(TextModel):
@ModelBase.register("DeepseekV2ForCausalLM")
@ModelBase.register("DeepseekV3ForCausalLM")
@ModelBase.register("KimiVLForConditionalGeneration")
class DeepseekV2Model(TextModel):
model_arch = gguf.MODEL_ARCH.DEEPSEEK2
@ -6158,6 +6165,13 @@ class DeepseekV2Model(TextModel):
_experts: list[dict[str, Tensor]] | None = None
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# skip vision tensors and remove "language_model." for Kimi-VL
if "vision_tower" in name or "multi_modal_projector" in name:
return []
if name.startswith("language_model."):
name = name.replace("language_model.", "")
# rename e_score_correction_bias tensors
if name.endswith("e_score_correction_bias"):
name = name.replace("e_score_correction_bias", "e_score_correction.bias")

View file

@ -59,6 +59,10 @@ parser.add_argument(
"--full", action="store_true",
help="download full list of models - make sure you have access to all of them",
)
parser.add_argument(
"--check-missing", action="store_true",
help="only check for missing pre-tokenizer hashes",
)
parser.add_argument(
"hf_token",
help="optional HF token",
@ -70,6 +74,10 @@ hf_token = args.hf_token if args.hf_token is not None else hf_token
if hf_token is None:
logger.warning("HF token not found. You can provide it as an argument or set it in ~/.cache/huggingface/token")
if args.check_missing and args.full:
logger.warning("Downloading full list of models requested, ignoring --check-missing!")
args.check_missing = False
# TODO: this string has to exercise as much pre-tokenizer functionality as possible
# will be updated with time - contributions welcome
CHK_TXT = '\n \n\n \n\n\n \t \t\t \t\n \n \n \n \n🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天 ------======= нещо на Български \'\'\'\'\'\'```````\"\"\"\"......!!!!!!?????? I\'ve been \'told he\'s there, \'RE you sure? \'M not sure I\'ll make it, \'D you like some tea? We\'Ve a\'lL'
@ -130,6 +138,7 @@ models = [
{"name": "midm-2.0", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/K-intelligence/Midm-2.0-Base-Instruct", },
{"name": "lfm2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LiquidAI/LFM2-Tokenizer"},
{"name": "exaone4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/EXAONE-4.0-32B", },
{"name": "mellum", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/JetBrains/Mellum-4b-base", },
]
# some models are known to be broken upstream, so we will skip them as exceptions
@ -147,6 +156,7 @@ pre_computed_hashes = [
{"name": "falcon-h1", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon-H1-7B-Base", "chkhsh": "3eda48b4c4dc7de733d1a8b3e3b4a85243dbbf704da2ee9d42c6beced8897896"},
{"name": "falcon-h1", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon-H1-34B-Base", "chkhsh": "48f8e02c0359c0bbdd82f26909171fac1c18a457bb47573ed1fe3bbb2c1cfd4b"},
{"name": "kimi-k2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/moonshotai/Kimi-K2-Base", "chkhsh": "81212dc7cdb7e0c1074ca62c5aeab0d43c9f52b8a737be7b12a777c953027890"},
{"name": "qwen2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen3-Embedding-0.6B", "chkhsh": "d4540891389ea895b53b399da6ac824becc30f2fba0e9ddbb98f92e55ca0e97c"},
]
@ -221,12 +231,13 @@ if not args.full:
all_models = models.copy()
models = [model for model in all_models if model["name"] not in existing_models]
logging.info(f"Downloading {len(models)} models...")
for model in models:
try:
download_model(model)
except Exception as e:
logger.error(f"Failed to download model {model['name']}. Error: {e}")
if not args.check_missing:
logging.info(f"Downloading {len(models)} models...")
for model in models:
try:
download_model(model)
except Exception as e:
logger.error(f"Failed to download model {model['name']}. Error: {e}")
# generate the source code for the convert_hf_to_gguf.py:get_vocab_base_pre() function:

View file

@ -315,8 +315,9 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
const bool gqa_opt_applies = ((Q->ne[2] / K->ne[2]) % 2 == 0) && mask; // The mma-based kernels have GQA-specific optimizations
const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16;
const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies &&
(Q->ne[3] > 1 || cc < GGML_CUDA_CC_ADA_LOVELACE) && !mma_needs_data_conversion;
const bool mma_faster_for_rtx4000 = Q->ne[3] > 1 || (Q->ne[2] > 4*K->ne[2] && K->ne[1] >= 8192);
const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && !mma_needs_data_conversion &&
(cc < GGML_CUDA_CC_ADA_LOVELACE || mma_faster_for_rtx4000);
const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % (2*warp_size) == 0;
if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) {
if (prec == GGML_PREC_DEFAULT) {

View file

@ -1853,6 +1853,9 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
ggml_cuda_pool_alloc<cuda_t> src0_alloc(ctx.pool());
ggml_cuda_pool_alloc<cuda_t> src1_alloc(ctx.pool());
bool is_src0_cont_2 = ggml_is_contiguous_2(src0);
bool is_src1_cont_2 = ggml_is_contiguous_2(src1);
// Handle src0
src0_ptr = (const cuda_t *) src0->data;
@ -1871,6 +1874,8 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
s11 = ne10;
s12 = ne11*s11;
s13 = ne12*s12;
is_src1_cont_2 = true;
}
// Setup destination buffer
@ -1919,15 +1924,19 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
const int64_t r2 = ne12/ne02;
const int64_t r3 = ne13/ne03;
if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
if (r2 == 1 && r3 == 1 && is_src0_cont_2 && is_src1_cont_2) {
// with a [0, 2, 1, 3] perm. and ne02==1 the matrix strides need to be determined from dim 3:
const int64_t sma = ne02 == 1 ? nb03/nb00 : nb02/nb00;
const int64_t smb = ne12 == 1 ? s13 : s12;
// there is no broadcast and src0, src1 are contiguous across dims 2, 3
// use cublasGemmStridedBatchedEx
CUBLAS_CHECK(
cublasGemmStridedBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
ne01, ne11, ne10,
alpha, src0_ptr, cu_data_type_a, nb01/nb00, nb02/nb00, // strideA
src1_ptr, cu_data_type_b, s11, s12, // strideB
beta, dst_t, cu_data_type, ne0, ne1*ne0, // strideC
alpha, src0_ptr, cu_data_type_a, nb01/nb00, sma, // strideA
src1_ptr, cu_data_type_b, s11, smb, // strideB
beta, dst_t, cu_data_type, ne0, ne1*ne0, // strideC
ne12*ne13,
cu_compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));

View file

@ -1,65 +1,75 @@
#include "im2col.cuh"
#define MIN(a, b) (a) < (b) ? (a) : (b)
#define MAX_GRIDDIM_Z 65535
template <typename T>
static __global__ void im2col_kernel(
const float * x, T * dst, int64_t batch_offset,
int64_t offset_delta, int64_t IC, int64_t IW, int64_t IH, int64_t OH, int64_t OW, int64_t KW, int64_t KH, int64_t pelements, int64_t CHW,
const float * x, T * dst,
int64_t IC, int64_t IW, int64_t IH, int64_t OH, int64_t OW, int64_t KW, int64_t KH,
int64_t IC_IH_IW, int64_t IH_IW, int64_t N_OH, int64_t KH_KW, int64_t IC_KH_KW,
int s0, int s1, int p0, int p1, int d0, int d1) {
const int64_t i = threadIdx.x + blockIdx.x * blockDim.x;
if (i >= pelements) {
if (i >= IC_KH_KW) {
return;
}
const int64_t ksize = OW * KH;
const int64_t kx = i / ksize;
const int64_t kd = kx * ksize;
const int64_t ky = (i - kd) / OW;
const int64_t ix = i % OW;
const int64_t iic = i / (KH_KW);
const int64_t rem = i - iic * KH_KW;
const int64_t ikh = rem / KW;
const int64_t ikw = rem - ikh * KW;
const int64_t oh = blockIdx.y;
const int64_t batch = blockIdx.z / IC;
const int64_t ic = blockIdx.z % IC;
const int64_t iow = blockIdx.y;
for (int64_t iz = blockIdx.z; iz < N_OH; iz+=MAX_GRIDDIM_Z) {
const int64_t in = iz / OH;
const int64_t ioh = iz - in * OH;
const int64_t iiw = ix * s0 + kx * d0 - p0;
const int64_t iih = oh * s1 + ky * d1 - p1;
const int64_t iiw = iow * s0 + ikw * d0 - p0;
const int64_t iih = ioh * s1 + ikh * d1 - p1;
const int64_t offset_dst =
((batch * OH + oh) * OW + ix) * CHW +
(ic * (KW * KH) + ky * KW + kx);
const int64_t offset_dst =
((in * OH + ioh) * OW + iow) * IC_KH_KW + iic * KH_KW + ikh * KW + ikw;
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
dst[offset_dst] = 0.0f;
} else {
const int64_t offset_src = ic * offset_delta + batch * batch_offset;
dst[offset_dst] = x[offset_src + iih * IW + iiw];
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
dst[offset_dst] = 0.0f;
} else {
const int64_t offset_src = iic * IC_IH_IW + in * IH_IW;
dst[offset_dst] = x[offset_src + iih * IW + iiw];
}
}
}
// im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
template <typename T>
static void im2col_cuda(const float * x, T* dst,
int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, int64_t KH, int64_t IC,
int64_t batch, int64_t batch_offset, int64_t offset_delta,
int64_t N, int64_t IC_IH_IW, int64_t IH_IW,
int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {
const int parallel_elements = OW * KW * KH;
const int num_blocks = (parallel_elements + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE;
dim3 block_nums(num_blocks, OH, batch * IC);
im2col_kernel<<<block_nums, CUDA_IM2COL_BLOCK_SIZE, 0, stream>>>(x, dst, batch_offset, offset_delta, IC, IW, IH, OH, OW, KW, KH, parallel_elements, (IC * KH * KW), s0, s1, p0, p1, d0, d1);
const int64_t IC_KH_KW = IC * KH * KW;
const int64_t num_blocks = (IC_KH_KW + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE;
const int64_t N_OH = N * OH;
const int64_t KH_KW = KW*KH;
dim3 block_nums(num_blocks, OW, MIN(N_OH, MAX_GRIDDIM_Z));
im2col_kernel<<<block_nums, MIN(IC_KH_KW, CUDA_IM2COL_BLOCK_SIZE) , 0, stream>>>(x, dst, IC, IW, IH, OH, OW, KW, KH,
IC_IH_IW, IH_IW, N_OH, KH_KW, IC_KH_KW,
s0, s1, p0, p1, d0, d1);
}
static void im2col_cuda_f16(const float * x, half * dst,
int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, int64_t KH, int64_t IC,
int64_t batch, int64_t batch_offset, int64_t offset_delta,
int64_t N, int64_t IC_IH_IW, int64_t IH_IW,
int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {
im2col_cuda<half>(x, dst, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, offset_delta, s0, s1, p0, p1, d0, d1, stream);
im2col_cuda<half>(x, dst, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, d0, d1, stream);
}
static void im2col_cuda_f32(const float * x, float * dst,
int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, int64_t KH, int64_t IC,
int64_t batch, int64_t batch_offset, int64_t offset_delta,
int64_t N, int64_t IC_IH_IW, int64_t IH_IW,
int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {
im2col_cuda<float>(x, dst, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, offset_delta, s0, s1, p0, p1, d0, d1, stream);
im2col_cuda<float>(x, dst, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, d0, d1, stream);
}
void ggml_cuda_op_im2col(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@ -91,13 +101,13 @@ void ggml_cuda_op_im2col(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const int64_t OH = is_2D ? dst->ne[2] : 1;
const int64_t OW = dst->ne[1];
const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
const int64_t batch = src1->ne[is_2D ? 3 : 2];
const size_t batch_offset = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32
const int64_t IC_IH_IW = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
const int64_t N = src1->ne[is_2D ? 3 : 2];
const int64_t IH_IW = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32
if(dst->type == GGML_TYPE_F16) {
im2col_cuda_f16(src1_d, (half *) dst_d, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, stream);
im2col_cuda_f16(src1_d, (half *) dst_d, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, d0, d1, stream);
} else {
im2col_cuda_f32(src1_d, (float *) dst_d, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, stream);
im2col_cuda_f32(src1_d, (float *) dst_d, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, d0, d1, stream);
}
}
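For reference, a minimal CPU sketch of the same [N, IC, IH, IW] to [N, OH, OW, IC*KH*KW] mapping the rewritten kernel produces (threads cover IC*KH*KW, blockIdx.y covers OW, blockIdx.z loops over N*OH). The function name, the contiguous-input assumption and the output-size lambda are illustrative only, not part of the CUDA code:

#include <cstdint>
#include <vector>

// Illustrative CPU reference for im2col on a contiguous [N, IC, IH, IW] input.
// The destination index matches the kernel's offset_dst computation above.
static std::vector<float> im2col_ref(const std::vector<float> & x,
                                     int64_t N, int64_t IC, int64_t IH, int64_t IW,
                                     int64_t KH, int64_t KW,
                                     int s0, int s1, int p0, int p1, int d0, int d1) {
    // same formula as ggml_calc_conv_output_size
    auto out_size = [](int64_t ins, int64_t ks, int s, int p, int d) -> int64_t {
        return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
    };
    const int64_t OH = out_size(IH, KH, s1, p1, d1);
    const int64_t OW = out_size(IW, KW, s0, p0, d0);

    std::vector<float> dst(N * OH * OW * IC * KH * KW, 0.0f);
    for (int64_t in = 0; in < N; in++) {
        for (int64_t ioh = 0; ioh < OH; ioh++) {
            for (int64_t iow = 0; iow < OW; iow++) {
                for (int64_t iic = 0; iic < IC; iic++) {
                    for (int64_t ikh = 0; ikh < KH; ikh++) {
                        for (int64_t ikw = 0; ikw < KW; ikw++) {
                            const int64_t iih = ioh * s1 + ikh * d1 - p1; // input row, may land in padding
                            const int64_t iiw = iow * s0 + ikw * d0 - p0; // input col, may land in padding
                            const int64_t dst_idx = ((in * OH + ioh) * OW + iow) * (IC * KH * KW)
                                                  + iic * KH * KW + ikh * KW + ikw;
                            if (iih >= 0 && iih < IH && iiw >= 0 && iiw < IW) {
                                dst[dst_idx] = x[((in * IC + iic) * IH + iih) * IW + iiw];
                            } // else: stays 0, like the kernel's padding branch
                        }
                    }
                }
            }
        }
    }
    return dst;
}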

View file

@ -238,6 +238,7 @@ enum vk_device_architecture {
AMD_RDNA2,
AMD_RDNA3,
INTEL_XE2,
NVIDIA_PRE_TURING,
};
// HSK x HSV
@ -331,10 +332,33 @@ static vk_device_architecture get_device_architecture(const vk::PhysicalDevice&
// https://www.intel.com/content/www/us/en/docs/oneapi/optimization-guide-gpu/2025-0/intel-xe-gpu-architecture.html
return vk_device_architecture::INTEL_XE2;
}
} else if (props.vendorID == VK_VENDOR_ID_NVIDIA) {
const std::vector<vk::ExtensionProperties> ext_props = device.enumerateDeviceExtensionProperties();
bool cooperative_matrix = false;
// Detect "pre-turing" based on lack of coopmat support.
for (const auto& properties : ext_props) {
if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0) {
cooperative_matrix = true;
break;
}
}
if (!cooperative_matrix) {
return vk_device_architecture::NVIDIA_PRE_TURING;
}
}
return vk_device_architecture::OTHER;
}
enum vk_conv_shapes {
CONV_SHAPE_128x128,
CONV_SHAPE_64x32,
CONV_SHAPE_32x256,
CONV_SHAPE_COUNT,
};
struct vk_device_struct {
std::recursive_mutex mutex;
@ -499,8 +523,8 @@ struct vk_device_struct {
vk_pipeline pipeline_rwkv_wkv6_f32;
vk_pipeline pipeline_rwkv_wkv7_f32;
vk_pipeline pipeline_opt_step_adamw_f32;
vk_pipeline pipeline_conv2d_f32;
vk_pipeline pipeline_conv2d_f16_f32;
vk_pipeline pipeline_conv2d_f32[CONV_SHAPE_COUNT];
vk_pipeline pipeline_conv2d_f16_f32[CONV_SHAPE_COUNT];
vk_pipeline pipeline_conv2d_dw_whcn_f32;
vk_pipeline pipeline_conv2d_dw_cwhn_f32;
@ -924,8 +948,22 @@ struct vk_op_conv2d_push_constants {
uint32_t nb1;
uint32_t nb2;
uint32_t nb3;
// init_fastdiv_values constants for dividing by KW, KW*KH, OW, OW*OH
uint32_t KWmp; uint32_t KWL;
uint32_t KWKHmp; uint32_t KWKHL;
uint32_t OWmp; uint32_t OWL;
uint32_t OWOHmp; uint32_t OWOHL;
};
template <> void init_pushconst_fastdiv(vk_op_conv2d_push_constants &p) {
// Compute magic values to divide by KW, KW*KH, OW, OW*OH
init_fastdiv_values(p.KW, p.KWmp, p.KWL);
init_fastdiv_values(p.KW*p.KH, p.KWKHmp, p.KWKHL);
init_fastdiv_values(p.OW, p.OWmp, p.OWL);
init_fastdiv_values(p.OW*p.OH, p.OWOHmp, p.OWOHL);
}
struct vk_op_conv2d_dw_push_constants {
uint32_t ne;
uint32_t batches;
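The init_pushconst_fastdiv specialization above fills magic multiplier/shift pairs so the conv2d shader can replace integer division by KW, KW*KH, OW and OW*OH with a multiply-high plus shift (see the fastdiv() helper added to conv2d_mm.comp further below). A host-side sketch of that scheme under the usual (mulhi(n, mp) + n) >> L construction; init_fastdiv_values_sketch is an assumption for illustration and may differ in detail from the real init_fastdiv_values in ggml-vulkan.cpp:

#include <cassert>
#include <cstdint>
#include <cstdio>

// Sketch: pick L = ceil(log2(d)) and mp = floor(2^32 * (2^L - d) / d) + 1 so that
// n / d == (mulhi(n, mp) + n) >> L for the index ranges used here (n well below 2^31).
static void init_fastdiv_values_sketch(uint32_t d, uint32_t & mp, uint32_t & L) {
    L = 0;
    while (L < 32 && (uint32_t(1) << L) < d) {
        L++;
    }
    mp = uint32_t(((uint64_t(1) << 32) * ((uint64_t(1) << L) - d)) / d + 1);
}

static uint32_t fastdiv_sketch(uint32_t n, uint32_t mp, uint32_t L) {
    const uint32_t msbs = uint32_t((uint64_t(n) * mp) >> 32); // high half, like umulExtended in the shader
    return (msbs + n) >> L;
}

int main() {
    const uint32_t divisors[] = { 3, 9, 28, 224, 640, 4480 }; // hypothetical KW, KW*KH, OW, OW*OH values
    for (uint32_t d : divisors) {
        uint32_t mp, L;
        init_fastdiv_values_sketch(d, mp, L);
        for (uint32_t n = 0; n < (1u << 20); n++) {
            assert(n / d == fastdiv_sketch(n, mp, L));
        }
    }
    printf("fastdiv matches integer division on the tested range\n");
    return 0;
}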
@ -2084,12 +2122,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
s_mmq_wg_denoms = { 32, 64, 1 };
// spec constants and tile sizes for quant matmul (Qi_K)
l_warptile_mmq_k = { 256, 64, 128, 64, 1 };
m_warptile_mmq_k = { 256, 32, 64, 64, 0 };
s_warptile_mmq_k = { 256, 32, 32, 128, 0 };
l_mmq_wg_denoms_k = { 64, 128, 1 };
m_mmq_wg_denoms_k = { 32, 64, 1 };
s_mmq_wg_denoms_k = { 32, 32, 1 };
l_warptile_mmq_k = { 256, 128, 256, 64, 1 };
m_warptile_mmq_k = { 256, 128, 128, 64, 1 };
s_warptile_mmq_k = { 256, 32, 64, 128, 0 };
l_mmq_wg_denoms_k = { 128, 256, 1 };
m_mmq_wg_denoms_k = { 128, 128, 1 };
s_mmq_wg_denoms_k = { 32, 64, 1 };
// spec constants and tile sizes for quant matmul_id
l_warptile_mmqid = { 256, 128, 128, 16, 0 };
@ -2863,7 +2901,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32[i], "mul_mat_vec_p021_f16_f32"+std::to_string(i+1), mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {device->subgroup_size, i + 1}, 1, true);
}
}
ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 9 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 12 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
@ -3064,48 +3102,105 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
// conv2d
uint32_t conv2d_WG_SIZE = 256;
uint32_t conv2d_BS_K = 128;
uint32_t conv2d_BS_CRS = 16;
uint32_t use_collectives = 0; // Enables subgroup ops for preventing the re-calculation of indices.
if (device->subgroup_shuffle &&
device->vendor_id != VK_VENDOR_ID_INTEL) { // Do not enable collectives on Intel, see PR 14316
use_collectives = 1;
conv2d_BS_CRS = std::min(
device->subgroup_size,
conv2d_BS_CRS); // CRS block size should be capped at sugroup size for correctness when shuffle is used.
}
uint32_t conv2d_BS_NPQ = 128;
uint32_t conv2d_TS_K = 8;
uint32_t conv2d_shmem_req =
(conv2d_BS_K * (conv2d_BS_CRS + 1) + conv2d_BS_CRS * (conv2d_BS_NPQ + 1)) * sizeof(float);
if (device->properties.limits.maxComputeSharedMemorySize < conv2d_shmem_req) {
conv2d_BS_CRS = 8;
if (use_collectives) {
conv2d_BS_CRS = std::min(device->subgroup_size, conv2d_BS_CRS);
}
}
for (uint32_t s = 0; s < CONV_SHAPE_COUNT; ++s) {
uint32_t conv2d_WG_SIZE = 256;
uint32_t conv2d_BS_K = 128;
uint32_t conv2d_BS_CRS = 16;
uint32_t use_collectives = 0; // Enables subgroup ops for preventing the re-calculation of indices.
uint32_t conv2d_BS_NPQ = 128;
uint32_t conv2d_TS_K = 8;
uint32_t conv2d_SHMEM_PAD = 4;
bool conv2d_UNROLL = true;
if (use_collectives) {
ggml_vk_create_pipeline(
device, device->pipeline_conv2d_f32, "conv2d_f32", conv2d_f32_len, conv2d_f32_data, "main", 3,
sizeof(vk_op_conv2d_push_constants), { conv2d_BS_K, conv2d_BS_NPQ, 1 },
{ conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives }, 1, true, true);
ggml_vk_create_pipeline(
device, device->pipeline_conv2d_f16_f32, "conv2d_f16_f32", conv2d_f16_f32_len, conv2d_f16_f32_data, "main", 3,
sizeof(vk_op_conv2d_push_constants), { conv2d_BS_K, conv2d_BS_NPQ, 1 },
{ conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives }, 1, true, true);
} else {
ggml_vk_create_pipeline(
device, device->pipeline_conv2d_f32, "conv2d_f32", conv2d_f32_len, conv2d_f32_data, "main", 3,
sizeof(vk_op_conv2d_push_constants), { conv2d_BS_K, conv2d_BS_NPQ, 1 },
{ conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives }, 1, true,
false);
ggml_vk_create_pipeline(
device, device->pipeline_conv2d_f16_f32, "conv2d_f16_f32", conv2d_f16_f32_len, conv2d_f16_f32_data, "main", 3,
sizeof(vk_op_conv2d_push_constants), { conv2d_BS_K, conv2d_BS_NPQ, 1 },
{ conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives }, 1, true,
false);
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
if (device->coopmat2) {
conv2d_SHMEM_PAD = 8; // 8 float16_t
}
#endif
if (device->vendor_id == VK_VENDOR_ID_INTEL) {
conv2d_SHMEM_PAD = 0;
conv2d_UNROLL = false;
} else if (device->vendor_id == VK_VENDOR_ID_AMD) {
conv2d_SHMEM_PAD = device->architecture == vk_device_architecture::AMD_GCN ? 1 : 4;
}
switch (s) {
default:
case CONV_SHAPE_128x128:
conv2d_BS_K = 128;
conv2d_BS_NPQ = 128;
conv2d_BS_CRS = 16;
if (device->vendor_id == VK_VENDOR_ID_AMD && device->architecture != vk_device_architecture::AMD_GCN) {
conv2d_UNROLL = false;
}
break;
case CONV_SHAPE_64x32:
conv2d_BS_K = 64;
conv2d_BS_NPQ = 32;
conv2d_BS_CRS = 32;
conv2d_TS_K = 4;
break;
case CONV_SHAPE_32x256:
conv2d_BS_K = 32;
conv2d_BS_NPQ = 256;
conv2d_BS_CRS = 16;
break;
}
// Use collectives on pre-Turing NVIDIA GPUs and GCN AMD cards, which had slower integer math.
bool allow_collectives_nv = device->vendor_id != VK_VENDOR_ID_NVIDIA ||
device->architecture == vk_device_architecture::NVIDIA_PRE_TURING;
bool allow_collectives_amd = device->vendor_id != VK_VENDOR_ID_AMD ||
device->architecture == vk_device_architecture::AMD_GCN;
if (device->subgroup_shuffle &&
device->vendor_id != VK_VENDOR_ID_INTEL && // Do not enable collectives on Intel, see PR 14316.
allow_collectives_nv &&
allow_collectives_amd) {
use_collectives = 1;
conv2d_BS_CRS = std::min(
device->subgroup_size,
conv2d_BS_CRS); // CRS block size should be capped at subgroup size for correctness when shuffle is used.
}
uint32_t conv2d_shmem_req =
(conv2d_BS_K * (conv2d_BS_CRS + conv2d_SHMEM_PAD) + conv2d_BS_CRS * (conv2d_BS_NPQ + conv2d_SHMEM_PAD)) * sizeof(float);
if (device->properties.limits.maxComputeSharedMemorySize < conv2d_shmem_req) {
conv2d_BS_CRS = 8;
if (use_collectives) {
conv2d_BS_CRS = std::min(device->subgroup_size, conv2d_BS_CRS);
}
}
std::array<uint32_t, 3> wg_denoms = { conv2d_BS_K, conv2d_BS_NPQ, 1 };
std::vector<uint32_t> spec_constants = { conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives, conv2d_SHMEM_PAD };
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
if (device->coopmat2) {
ggml_vk_create_pipeline(
device, device->pipeline_conv2d_f32[s], "conv2d_f32", conv2d_f32_cm2_len, conv2d_f32_cm2_data, "main", 3,
sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives);
ggml_vk_create_pipeline(
device, device->pipeline_conv2d_f16_f32[s], "conv2d_f16_f32", conv2d_f16_f32_cm2_len, conv2d_f16_f32_cm2_data, "main", 3,
sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives);
} else
#endif
if (conv2d_UNROLL) {
ggml_vk_create_pipeline(
device, device->pipeline_conv2d_f32[s], "conv2d_f32", conv2d_f32_unroll_len, conv2d_f32_unroll_data, "main", 3,
sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives);
ggml_vk_create_pipeline(
device, device->pipeline_conv2d_f16_f32[s], "conv2d_f16_f32", conv2d_f16_f32_unroll_len, conv2d_f16_f32_unroll_data, "main", 3,
sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives);
} else {
ggml_vk_create_pipeline(
device, device->pipeline_conv2d_f32[s], "conv2d_f32", conv2d_f32_len, conv2d_f32_data, "main", 3,
sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives);
ggml_vk_create_pipeline(
device, device->pipeline_conv2d_f16_f32[s], "conv2d_f16_f32", conv2d_f16_f32_len, conv2d_f16_f32_data, "main", 3,
sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives);
}
}
ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
@ -4967,26 +5062,37 @@ static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, siz
ggml_vk_queue_command_pools_cleanup(dst->device);
}
static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, int m, int n, int k, const vk_pipeline& pipeline) {
static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, uint32_t m, uint32_t n, uint32_t k, const vk_pipeline& pipeline) {
VK_LOG_DEBUG("ggml_vk_guess_split_k(" << m << ", " << n << ", " << k << ")");
uint32_t split_k = 1;
if (ctx->device->shader_core_count != 0 && m >= (int)pipeline->wg_denoms[0] && n >= (int)pipeline->wg_denoms[1]) {
if (ctx->device->shader_core_count != 0 && m >= pipeline->wg_denoms[0] && n >= pipeline->wg_denoms[1]) {
// If k is 'large' and the SMs will fill less than halfway, use split_k.
uint32_t m_tiles = CEIL_DIV(m, pipeline->wg_denoms[0]);
uint32_t n_tiles = CEIL_DIV(n, pipeline->wg_denoms[1]);
if (k >= 2048 && m_tiles * n_tiles < ctx->device->shader_core_count / 2) {
split_k = ctx->device->shader_core_count / (m_tiles * n_tiles);
// Clamp to 2 or 4
split_k = std::min(split_k, 4u);
if (split_k == 3) {
split_k = 2;
if (k >= 2048) {
if (m_tiles * n_tiles <= ctx->device->shader_core_count / 2) {
split_k = ctx->device->shader_core_count / (m_tiles * n_tiles);
} else if (m_tiles * n_tiles <= ctx->device->shader_core_count * 2 / 3) {
split_k = 3;
}
if (ctx->device->coopmat2) {
// coopmat2 shader expects splits to be aligned to 256
while (split_k > 1 && ((k / split_k) % 256) != 0) {
split_k /= 2;
// Cap the split at 8x. Unless k is huge this is a lot of overhead.
split_k = std::min(split_k, 8u);
// ggml_vk_matmul will align the splits to be a multiple of 256.
// If this rounded up size would cause the last split to be empty,
// then reduce the split count.
while (true) {
if (split_k == 1) {
break;
}
uint32_t k_split = CEIL_DIV(k, split_k);
k_split = ROUNDUP_POW2(k_split, 256);
if (k_split * (split_k - 1) < k) {
break;
}
split_k--;
}
}
}
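The reworked heuristic caps split_k at 8, rounds each k split up to a multiple of 256 (k-quant alignment, matching the ggml_vk_matmul change below) and then lowers split_k while the rounded-up split would leave the last piece with no work. A small standalone sketch of just that trimming step, with hypothetical sizes; CEIL_DIV and ROUNDUP_POW2 are assumed to have their usual ggml meanings:

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Illustrative only: shrink split_k until the last split, after rounding each
// split up to a multiple of 256, still receives some work.
static uint32_t trim_split_k(uint32_t k, uint32_t split_k) {
    split_k = std::min(split_k, 8u);                      // cap at 8x, as in the diff
    while (split_k > 1) {
        uint32_t k_split = (k + split_k - 1) / split_k;   // CEIL_DIV(k, split_k)
        k_split = (k_split + 255) & ~uint32_t(255);       // ROUNDUP_POW2(k_split, 256)
        if (k_split * (split_k - 1) < k) {
            break;                                        // last split is non-empty
        }
        split_k--;
    }
    return split_k;
}

int main() {
    // k = 2560 with a requested 8-way split: 2560/8 = 320, rounded up to 512,
    // but 512 * 7 >= 2560, so the split count drops until the last piece has work.
    printf("k=2560, requested 8 -> %u\n", trim_split_k(2560, 8)); // prints 5
    printf("k=8192, requested 8 -> %u\n", trim_split_k(8192, 8)); // prints 8
    return 0;
}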
@ -4998,9 +5104,22 @@ static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx,
VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")");
if (ctx->device->coopmat2) {
const uint32_t shader_core_count = ctx->device->shader_core_count;
const uint32_t tiles_l = CEIL_DIV(m, mmp->a_l->wg_denoms[0]) * CEIL_DIV(n, mmp->a_l->wg_denoms[1]);
const uint32_t tiles_m = CEIL_DIV(m, mmp->a_m->wg_denoms[0]) * CEIL_DIV(n, mmp->a_m->wg_denoms[1]);
// Use large shader when the N dimension is greater than the medium shader's tile size
uint32_t crossover_large = mmp->m->wg_denoms[1];
if ((ctx->device->mul_mat_l[src0_type] && (n > crossover_large)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_s[src0_type])) {
// Prefer large over medium if either:
// - medium or large tiles would overfill the GPU
// - large tiles with a split_k==3 fits in the GPU and medium tiles with split_k==2 does not
// (medium with split_k==2 is probably better if it fits - more workgroups running and less split_k overhead)
bool prefer_large = tiles_m > shader_core_count || tiles_l > shader_core_count ||
// split_k==3 with large tiles likely better than medium tiles with no split_k.
(tiles_l <= shader_core_count / 3 && tiles_m > shader_core_count / 2);
if ((ctx->device->mul_mat_l[src0_type] && (n > crossover_large && prefer_large)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_s[src0_type])) {
return aligned ? mmp->a_l : mmp->l;
}
// Use medium shader when the N dimension is greater than the small shader's tile size
@ -5044,7 +5163,11 @@ static void ggml_vk_matmul(
GGML_ASSERT(batch_stride_d == m * n);
const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, CEIL_DIV(k, split_k), ne02, ne12, broadcast2, broadcast3, padded_n };
// Round the split size up to a multiple of 256 (k-quant alignment)
uint32_t k_split = CEIL_DIV(k, split_k);
k_split = ROUNDUP_POW2(k_split, 256);
const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k_split, ne02, ne12, broadcast2, broadcast3, padded_n };
// Make sure enough workgroups get assigned for split k to work
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, batch });
ggml_vk_sync_buffers(subctx);
@ -5766,7 +5889,7 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
const uint64_t ne00 = src0->ne[0];
const uint64_t ne01 = src0->ne[1];
const uint64_t ne02 = src0->ne[2];
// const uint64_t ne03 = src0->ne[3];
const uint64_t ne03 = src0->ne[3];
const uint64_t nb01 = src0->nb[1];
const uint64_t nb02 = src0->nb[2];
@ -5778,7 +5901,12 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
const uint64_t ne12 = src1->ne[2];
// const uint64_t ne13 = src1->ne[3];
const uint32_t nb03 = (uint32_t)(src0->nb[3] / sizeof(ggml_fp16_t));
const uint32_t nb13 = (uint32_t)(src1->nb[3] / sizeof(float));
const uint32_t nb23 = (uint32_t)(dst->nb[3] / sizeof(float));
GGML_ASSERT(ne11 == 1);
GGML_ASSERT(src0->ne[3] == src1->ne[3]); // checked in supports_op
ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context;
@ -5794,7 +5922,7 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
src1_uma = d_Qy != nullptr;
}
const uint64_t d_ne = ne01 * ne11 * ne12;
const uint64_t d_ne = ne01 * ne11 * ne12 * ne03;
const uint32_t row_stride_x = nb01 / sizeof(ggml_fp16_t);
const uint32_t channel_stride_x = nb02 / sizeof(ggml_fp16_t);
@ -5829,10 +5957,10 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
const uint64_t d_shader_offset = d_buf_offset - d_buffer_offset;
// compute
const std::array<uint32_t, 9> pc = { (uint32_t)ne00, (uint32_t)ne01, row_stride_x, channel_stride_x, channel_stride_y, (uint32_t)(ne12 / ne02), (uint32_t)ne12, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)) };
const std::array<uint32_t, 12> pc = { (uint32_t)ne00, (uint32_t)ne01, row_stride_x, channel_stride_x, channel_stride_y, (uint32_t)(ne12 / ne02), (uint32_t)ne12, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)), nb03, nb13, nb23 };
ggml_vk_sync_buffers(subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_nc_f16_f32,
{ vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, pc, { 1, (uint32_t)ne01, (uint32_t)ne12 });
{ vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, pc, { (uint32_t)ne03, (uint32_t)ne01, (uint32_t)ne12 });
}
static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
@ -6665,6 +6793,34 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
}
}
static std::array<uint32_t, 3> ggml_vk_get_conv_elements(const ggml_tensor *dst) {
const ggml_tensor *src0 = dst->src[0];
const ggml_tensor *src1 = dst->src[1];
// src0 - kernel: [KW, KH, Cin, Cout]
// src1 - input: [W, H, Cin, N]
// dst - result: [OW, OH, Cout, N]
// Copied from ggml.c: int64_t ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p, int d)
auto calc_conv_output_size = [](int64_t ins, int64_t ks, int s, int p, int d) -> int64_t {
return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
};
// parallelize in {OW/BS_K, OH/BS_NPQ, 1}
int64_t W = src1->ne[0];
int64_t H = src1->ne[1];
int64_t KW = src0->ne[0];
int64_t KH = src0->ne[1];
int64_t Cout = src0->ne[3];
int64_t N = src1->ne[3];
int64_t OH = calc_conv_output_size(H, KH, dst->op_params[1], dst->op_params[3], dst->op_params[5]);
int64_t OW = calc_conv_output_size(W, KW, dst->op_params[0], dst->op_params[2], dst->op_params[4]);
int64_t NPQ = N * OW * OH;
// Tile output matrix to (K/NB_K, NPQ/NB_NPQ, 1) workgroups
std::array<uint32_t, 3> elements = { static_cast<uint32_t>(Cout), static_cast<uint32_t>(NPQ), 1 };
return elements;
}
static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op) {
switch (op) {
case GGML_OP_GET_ROWS:
@ -6994,10 +7150,30 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
case GGML_OP_CONV_2D:
if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&
ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) {
auto elements = ggml_vk_get_conv_elements(dst);
vk_conv_shapes shape;
uint32_t tiles[CONV_SHAPE_COUNT];
for (uint32_t i = 0; i < CONV_SHAPE_COUNT; ++i) {
tiles[i] = CEIL_DIV(elements[0], ctx->device->pipeline_conv2d_f32[i]->wg_denoms[0]) * CEIL_DIV(elements[1], ctx->device->pipeline_conv2d_f32[i]->wg_denoms[1]);
}
// We can't query number of shader cores on Intel, use 32 as a placeholder
// so small convolutions will still choose a smaller tile.
const uint32_t shader_core_count = ctx->device->shader_core_count > 0 ? ctx->device->shader_core_count : 32;
if (elements[0] > 64 && tiles[CONV_SHAPE_128x128] >= shader_core_count * 2) {
shape = CONV_SHAPE_128x128;
} else if (elements[0] <= 32 && tiles[CONV_SHAPE_32x256] >= shader_core_count * 2) {
shape = CONV_SHAPE_32x256;
} else {
shape = CONV_SHAPE_64x32;
}
if (src0->type == GGML_TYPE_F32) {
return ctx->device->pipeline_conv2d_f32;
return ctx->device->pipeline_conv2d_f32[shape];
} else if (src0->type == GGML_TYPE_F16) {
return ctx->device->pipeline_conv2d_f16_f32;
return ctx->device->pipeline_conv2d_f16_f32[shape];
}
}
return nullptr;
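The selection above counts how many output tiles each candidate shape would produce and only picks the large 128x128 tile (or the skinny 32x256 one) when that still yields at least twice as many workgroups as shader cores. A compact sketch of the same ordering, using a hypothetical wg_denoms table and the 32-core placeholder used for Intel; it is not the pipeline lookup itself:

#include <cstdint>
#include <cstdio>

enum conv_shape { SHAPE_128x128, SHAPE_64x32, SHAPE_32x256 };

static uint32_t ceil_div(uint32_t a, uint32_t b) { return (a + b - 1) / b; }

// Illustrative only: pick a conv2d tile shape from Cout (elements[0]) and
// N*OH*OW (elements[1]), preferring big tiles only when they still produce
// at least twice as many workgroups as there are shader cores.
static conv_shape pick_conv_shape(uint32_t cout, uint32_t npq, uint32_t shader_core_count) {
    const uint32_t wg[3][2] = { { 128, 128 }, { 64, 32 }, { 32, 256 } }; // assumed BS_K x BS_NPQ per shape
    uint32_t tiles[3];
    for (int i = 0; i < 3; ++i) {
        tiles[i] = ceil_div(cout, wg[i][0]) * ceil_div(npq, wg[i][1]);
    }
    if (cout > 64 && tiles[SHAPE_128x128] >= shader_core_count * 2) {
        return SHAPE_128x128;
    }
    if (cout <= 32 && tiles[SHAPE_32x256] >= shader_core_count * 2) {
        return SHAPE_32x256;
    }
    return SHAPE_64x32;
}

int main() {
    const uint32_t cores = 32; // placeholder core count, as in the Intel fallback above
    printf("%d\n", pick_conv_shape(256, 64 * 64 * 4, cores));  // 0: wide output, many 128x128 tiles
    printf("%d\n", pick_conv_shape(16, 128 * 128 * 8, cores)); // 2: few output channels, 32x256 fits
    printf("%d\n", pick_conv_shape(48, 1024, cores));          // 1: small problem falls back to 64x32
    return 0;
}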
@ -7325,29 +7501,8 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
} break;
case GGML_OP_CONV_2D:
{
// src0 - kernel: [KW, KH, Cin, Cout]
// src1 - input: [W, H, Cin, N]
// dst - result: [OW, OH, Cout, N]
// Copied from ggml.c: int64_t ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p, int d)
auto calc_conv_output_size = [](int64_t ins, int64_t ks, int s, int p, int d) -> int64_t {
return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
};
// parallelize in {OW/BS_K, OH/BS_NPQ, 1}
int64_t W = src1->ne[0];
int64_t H = src1->ne[1];
int64_t KW = src0->ne[0];
int64_t KH = src0->ne[1];
int64_t Cout = src0->ne[3];
int64_t N = src1->ne[3];
int64_t OH = calc_conv_output_size(H, KH, dst->op_params[1], dst->op_params[3], dst->op_params[5]);
int64_t OW = calc_conv_output_size(W, KW, dst->op_params[0], dst->op_params[2], dst->op_params[4]);
int64_t NPQ = N * OW * OH;
// Tile output matrix to (K/NB_K, NPQ/NB_NPQ, 1) workgroups
elements = { static_cast<uint32_t>(Cout), static_cast<uint32_t>(NPQ), 1 };
}
break;
elements = ggml_vk_get_conv_elements(dst);
} break;
case GGML_OP_ADD:
case GGML_OP_SUB:
case GGML_OP_DIV:

View file

@ -1,14 +1,18 @@
#version 450
#extension GL_EXT_control_flow_attributes : enable
#ifdef COOPMAT2
#extension GL_NV_cooperative_matrix2 : enable
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#extension GL_KHR_memory_scope_semantics : enable
#endif
#ifdef USE_COLLECTIVES
# extension GL_KHR_shader_subgroup_shuffle : enable
#endif
#include "types.comp"
// Make spec constant
#define SHMEM_PAD 0
// shape notation: [dim(N), ..., dim(0)] -- stride(dim(j)) >= stride(dim(i)) if i > j
layout(binding = 0) readonly buffer A {
A_TYPE knl_data[];
@ -56,6 +60,12 @@ layout(push_constant) uniform parameter {
uint32_t nb1;
uint32_t nb2;
uint32_t nb3;
// fastdiv helper values
uint32_t KWmp; uint32_t KWL;
uint32_t KWKHmp; uint32_t KWKHL;
uint32_t OWmp; uint32_t OWL;
uint32_t OWOHmp; uint32_t OWOHL;
}
p;
@ -68,6 +78,7 @@ layout(constant_id = 3) const uint BS_NPQ = 128;
// Thread-tile sizes
layout(constant_id = 4) const uint TS_K = 8;
layout(constant_id = 5) const uint use_collectives = 1;
layout(constant_id = 6) const uint SHMEM_PAD = 4;
uint32_t tid = gl_LocalInvocationID.x;
const uint32_t WG_SIZE = gl_WorkGroupSize.x;
@ -85,6 +96,12 @@ uint32_t n_elems_out = K * NPQ;
// Number of blocktiles per input
uint32_t NB_CRS = splitWork(CRS, BS_CRS);
#ifdef COOPMAT2
#define SHMEM_TYPE float16_t
#else
#define SHMEM_TYPE float
#endif
const uint32_t Ash_stride = BS_CRS + SHMEM_PAD;
const uint32_t Bsh_stride = BS_NPQ + SHMEM_PAD;
@ -94,8 +111,8 @@ const uint32_t Bsh_numel = BS_CRS * BS_NPQ;
const uint32_t Ash_len = BS_K * Ash_stride;
const uint32_t Bsh_len = BS_CRS * Bsh_stride;
shared float Ash[Ash_len]; // K x CRS
shared float Bsh[Bsh_len]; // CRS x NPQ
shared SHMEM_TYPE Ash[Ash_len]; // K x CRS
shared SHMEM_TYPE Bsh[Bsh_len]; // CRS x NPQ
// Threadtile sizes
const uint32_t TS_NPQ = BS_K * BS_NPQ / WG_SIZE / TS_K;
@ -104,10 +121,6 @@ const uint32_t TS_NPQ = BS_K * BS_NPQ / WG_SIZE / TS_K;
const uint32_t NT_K = BS_K / TS_K;
const uint32_t NT_NPQ = BS_NPQ / TS_NPQ;
float regA[TS_K];
float regB[TS_NPQ];
float regC[TS_K][TS_NPQ];
/*
Compute
KxCRS @ CRSxNPQ = K x NPQ
@ -131,12 +144,44 @@ uint32_t Br = tid / BS_NPQ;
uint32_t Bc = tid % BS_NPQ;
const uint32_t BrpWg = WG_SIZE / BS_NPQ;
// see init_fastdiv_values in ggml-vulkan.cpp
uint fastdiv(uint n, uint mp, uint L) {
uint msbs, lsbs;
// msbs = mulhi(n, mp)
umulExtended(n, mp, msbs, lsbs);
return (msbs + n) >> L;
}
#ifdef COOPMAT2
#define ACC_TYPE float16_t
ACC_TYPE perElemOpStore(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem)
{
uint32_t K_idx = B_idx_K * BS_K + r;
uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + c;
uint32_t N_idx = fastdiv(NPQ_idx, p.OWOHmp, p.OWOHL); // divide by p.OH * p.OW;
uint32_t OH_idx = fastdiv(NPQ_idx - N_idx * p.OH * p.OW, p.OWmp, p.OWL); // divide by p.OW;
uint32_t OW_idx = NPQ_idx - N_idx * p.OH * p.OW - OH_idx * p.OW;
uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + K_idx * p.nb2 + N_idx * p.nb3;
if (K_idx < K && NPQ_idx < NPQ) {
dst_data[dst_idx] = D_TYPE(elem);
}
return elem;
}
#endif
void main() {
#ifdef COOPMAT2
coopmat<ACC_TYPE, gl_ScopeWorkgroup, BS_K, BS_NPQ, gl_MatrixUseAccumulator> matC;
matC = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BS_K, BS_NPQ, gl_MatrixUseAccumulator>(0.0);
#else
float regC[TS_K][TS_NPQ];
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
regC[T_ly][T_lx] = 0.0;
}
}
#endif
/* Advance block in CRS dim */
for (uint32_t B_idx_CRS = 0; B_idx_CRS < NB_CRS; B_idx_CRS++) {
uint32_t CRS_idx_a;
@ -151,9 +196,9 @@ void main() {
uint32_t cached_KW_idx;
if (use_collectives == 1) {
cached_CRS_idx = B_idx_CRS * BS_CRS + gl_SubgroupInvocationID;
cached_Cin_idx = cached_CRS_idx / (p.KW * p.KH);
cached_Cin_idx = fastdiv(cached_CRS_idx, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH);
uint32_t cached_CRS_remainder = (cached_CRS_idx - cached_Cin_idx * p.KW * p.KH);
cached_KH_idx = cached_CRS_remainder / p.KW;
cached_KH_idx = fastdiv(cached_CRS_remainder, p.KWmp, p.KWL); // divide by p.KW;
cached_KW_idx = cached_CRS_remainder - cached_KH_idx * p.KW;
CRS_idx_a = subgroupShuffle(cached_CRS_idx, Ac);
@ -162,16 +207,16 @@ void main() {
KW_idx_a = subgroupShuffle(cached_KW_idx, Ac);
} else {
CRS_idx_a = B_idx_CRS * BS_CRS + Ac; // Global CRS_idx_a (column index of A)
Cin_idx_a = CRS_idx_a / (p.KW * p.KH);
Cin_idx_a = fastdiv(CRS_idx_a, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH);
uint32_t CRS_remainder = CRS_idx_a - Cin_idx_a * p.KW * p.KH;
KH_idx_a = CRS_remainder / p.KW;
KH_idx_a = fastdiv(CRS_remainder, p.KWmp, p.KWL); // divide by p.KW;
KW_idx_a = CRS_remainder - KH_idx_a * p.KW;
}
#else
CRS_idx_a = B_idx_CRS * BS_CRS + Ac; // Global CRS_idx_a (column index of A)
Cin_idx_a = CRS_idx_a / (p.KW * p.KH);
Cin_idx_a = fastdiv(CRS_idx_a, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH);
CRS_remainder = CRS_idx_a - Cin_idx_a * p.KW * p.KH;
KH_idx_a = CRS_remainder / p.KW;
KH_idx_a = fastdiv(CRS_remainder, p.KWmp, p.KWL); // divide by p.KW;
KW_idx_a = CRS_remainder - KH_idx_a * p.KW;
#endif
@ -185,16 +230,16 @@ void main() {
if (K_idx >= K || CRS_idx_a >= CRS) {
val = 0.0;
}
Ash[B_ly * Ash_stride + B_lx] = val;
Ash[B_ly * Ash_stride + B_lx] = SHMEM_TYPE(val);
}
/* Load input to B_block: (BS_CRS x BS_NPQ) */
for (uint32_t r_offset = 0; r_offset < BS_CRS; r_offset += BrpWg) {
UNROLL for (uint32_t r_offset = 0; r_offset < BS_CRS; r_offset += BrpWg) {
uint32_t B_ly = r_offset + Br; /* Row index of B block */
uint32_t B_lx = Bc;
uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + B_lx; /* Global NPQ index (column index of B) */
uint32_t N_idx = NPQ_idx / (p.OH * p.OW);
uint32_t N_idx = fastdiv(NPQ_idx, p.OWOHmp, p.OWOHL); // divide by p.OH * p.OW;
uint32_t NPQ_remainder = NPQ_idx - N_idx * p.OH * p.OW;
uint32_t OH_idx = NPQ_remainder / p.OW;
uint32_t OH_idx = fastdiv(NPQ_remainder, p.OWmp, p.OWL); // divide by p.OW;
uint32_t OW_idx = NPQ_remainder - OH_idx * p.OW;
uint32_t CRS_idx_b;
@ -209,16 +254,16 @@ void main() {
KW_idx_b = subgroupShuffle(cached_KW_idx, r_offset + Br);
} else {
CRS_idx_b = B_idx_CRS * BS_CRS + B_ly; /* Global CRS index (row index of B) */
Cin_idx_b = CRS_idx_b / (p.KW * p.KH);
Cin_idx_b = fastdiv(CRS_idx_b, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH);
uint32_t CRS_remainder = CRS_idx_b - Cin_idx_b * p.KW * p.KH;
KH_idx_b = CRS_remainder / p.KW;
KH_idx_b = fastdiv(CRS_remainder, p.KWmp, p.KWL); // divide by p.KW;
KW_idx_b = CRS_remainder - KH_idx_b * p.KW;
}
#else
CRS_idx_b = B_idx_CRS * BS_CRS + B_ly; /* Global CRS index (row index of B) */
Cin_idx_b = CRS_idx_b / (p.KW * p.KH);
Cin_idx_b = fastdiv(CRS_idx_b, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH);
uint32_t CRS_remainder = CRS_idx_b - Cin_idx_b * p.KW * p.KH;
KH_idx_b = CRS_remainder / p.KW;
KH_idx_b = fastdiv(CRS_remainder, p.KWmp, p.KWL); // divide by p.KW;
KW_idx_b = CRS_remainder - KH_idx_b * p.KW;
#endif
@ -230,36 +275,55 @@ void main() {
if (CRS_idx_b >= CRS || NPQ_idx >= NPQ || H_idx < 0 || H_idx >= p.H || W_idx < 0 || W_idx >= p.W) {
val = 0.0;
}
Bsh[B_ly * Bsh_stride + B_lx] = val;
Bsh[B_ly * Bsh_stride + B_lx] = SHMEM_TYPE(val);
}
barrier();
for (uint32_t CRS_lidx = 0; CRS_lidx < BS_CRS; CRS_lidx++) {
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
regA[T_ly] = Ash[(T_y * TS_K + T_ly) * Ash_stride + CRS_lidx];
}
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
regB[T_lx] = Bsh[CRS_lidx * Bsh_stride + T_x * TS_NPQ + T_lx];
}
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
#ifdef COOPMAT2
coopmat<float16_t, gl_ScopeWorkgroup, BS_K, BS_CRS, gl_MatrixUseA> matA;
coopmat<float16_t, gl_ScopeWorkgroup, BS_CRS, BS_NPQ, gl_MatrixUseB> matB;
coopMatLoad(matA, Ash, 0, Ash_stride, gl_CooperativeMatrixLayoutRowMajor);
coopMatLoad(matB, Bsh, 0, Bsh_stride, gl_CooperativeMatrixLayoutRowMajor);
matC = coopMatMulAdd(matA, matB, matC);
#else
if (T_y * TS_K < K) {
UNROLL for (uint32_t CRS_lidx = 0; CRS_lidx < BS_CRS; CRS_lidx++) {
float regA[TS_K];
float regB[TS_NPQ];
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
regA[T_ly] = Ash[(T_y * TS_K + T_ly) * Ash_stride + CRS_lidx];
}
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
regC[T_ly][T_lx] = fma(regA[T_ly], regB[T_lx], regC[T_ly][T_lx]);
regB[T_lx] = Bsh[CRS_lidx * Bsh_stride + T_x * TS_NPQ + T_lx];
}
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
regC[T_ly][T_lx] = fma(regA[T_ly], regB[T_lx], regC[T_ly][T_lx]);
}
}
}
}
#endif
barrier();
}
/* Save C* */
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
uint32_t K_idx = B_idx_K * BS_K + T_y * TS_K + T_ly;
uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + T_x * TS_NPQ + T_lx;
uint32_t N_idx = NPQ_idx / (p.OH * p.OW);
uint32_t OH_idx = (NPQ_idx - N_idx * p.OH * p.OW) / p.OW;
uint32_t OW_idx = NPQ_idx - N_idx * p.OH * p.OW - OH_idx * p.OW;
uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + K_idx * p.nb2 + N_idx * p.nb3;
if (K_idx < K && NPQ_idx < NPQ) {
dst_data[dst_idx] = regC[T_ly][T_lx];
#ifdef COOPMAT2
coopMatPerElementNV(matC, matC, perElemOpStore);
#else
if (T_y * TS_K < K) {
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
uint32_t K_idx = B_idx_K * BS_K + T_y * TS_K + T_ly;
uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + T_x * TS_NPQ + T_lx;
uint32_t N_idx = fastdiv(NPQ_idx, p.OWOHmp, p.OWOHL); // divide by p.OH * p.OW;
uint32_t OH_idx = fastdiv(NPQ_idx - N_idx * p.OH * p.OW, p.OWmp, p.OWL); // divide by p.OW;
uint32_t OW_idx = NPQ_idx - N_idx * p.OH * p.OW - OH_idx * p.OW;
uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + K_idx * p.nb2 + N_idx * p.nb3;
if (K_idx < K && NPQ_idx < NPQ) {
dst_data[dst_idx] = regC[T_ly][T_lx];
}
}
}
}
#endif
}

View file

@ -26,6 +26,9 @@ layout (push_constant) uniform parameter
uint ne12;
uint b_offset;
uint d_offset;
uint nb03;
uint nb13;
uint nb23;
} p;
shared FLOAT_TYPE tmp[BLOCK_SIZE];
@ -34,6 +37,7 @@ void main() {
const uint tid = gl_LocalInvocationID.x;
const uint row_x = gl_GlobalInvocationID.y;
const uint channel = gl_GlobalInvocationID.z;
const uint i3 = gl_WorkGroupID.x;
const uint channel_x = channel / p.channel_x_divisor;
const uint channel_y = channel % p.ne12;
@ -41,7 +45,7 @@ void main() {
const uint nrows_dst = p.nrows_x;
const uint row_dst = row_x;
const uint idst = channel*nrows_dst + row_dst;
const uint idst = i3*p.nb23 + channel*nrows_dst + row_dst;
FLOAT_TYPE temp = 0.0f;
@ -58,8 +62,8 @@ void main() {
const uint row_y = col_x;
const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
const uint iy = channel_y*p.channel_stride_y + row_y;
const uint ix = i3*p.nb03 + channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
const uint iy = i3*p.nb13 + channel_y*p.channel_stride_y + row_y;
const vec4 av4 = vec4(data_a_v4[ix / 4]);
const vec4 bv4 = vec4(data_b_v4[iy / 4]);
@ -74,8 +78,8 @@ void main() {
const uint row_y = col_x;
const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
const uint iy = channel_y*p.channel_stride_y + row_y;
const uint ix = i3*p.nb03 + channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
const uint iy = i3*p.nb13 + channel_y*p.channel_stride_y + row_y;
const vec4 av4 = vec4(data_a_v4[ix / 4]);
const vec4 bv4 = vec4(data_b_v4[iy / 4]);
@ -91,8 +95,8 @@ void main() {
const uint row_y = col_x;
const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
const uint iy = channel_y*p.channel_stride_y + row_y;
const uint ix = i3*p.nb03 + channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
const uint iy = i3*p.nb13 + channel_y*p.channel_stride_y + row_y;
const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]);

View file

@ -669,8 +669,16 @@ void process_shaders() {
string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
string_to_spv("conv2d_f32", "conv2d_mm.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}});
string_to_spv("conv2d_f16_f32", "conv2d_mm.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}});
string_to_spv("conv2d_f32_unroll", "conv2d_mm.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", "[[unroll]]"}});
string_to_spv("conv2d_f16_f32_unroll", "conv2d_mm.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", "[[unroll]]"}});
string_to_spv("conv2d_f32", "conv2d_mm.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", ""}});
string_to_spv("conv2d_f16_f32", "conv2d_mm.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", ""}});
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
string_to_spv("conv2d_f32", "conv2d_mm.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", "[[unroll]]"}, {"COOPMAT2", "1"}}, true, false, true);
string_to_spv("conv2d_f16_f32", "conv2d_mm.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", "[[unroll]]"}, {"COOPMAT2", "1"}}, true, false, true);
#endif
string_to_spv("conv2d_dw_whcn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}}));
string_to_spv("conv2d_dw_cwhn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}}));

View file

@ -33,6 +33,7 @@ class TensorNameMap:
"language_model.model.embed_tokens", # llama4
"encoder", # neobert
"model.transformer.wte", # llada
"embed_tokens", # qwen3-embedding
),
# Token type embeddings
@ -143,6 +144,7 @@ class TensorNameMap:
"transformer_encoder.{bid}.attention_norm", # neobert
"model.layers.{bid}.operator_norm", # lfm2
"model.transformer.blocks.{bid}.attn_norm", # llada
"layers.{bid}.input_layernorm", # qwen3-embedding
),
# Attention norm 2
@ -188,6 +190,7 @@ class TensorNameMap:
"transformer.h.{bid}.attn.attention.q_proj", # exaone
"model.layers.{bid}.self_attn.q_proj", # llama4
"model.transformer.blocks.{bid}.q_proj", # llada
"layers.{bid}.self_attn.q_proj", # qwen3-embedding
),
# Attention key
@ -205,6 +208,7 @@ class TensorNameMap:
"transformer.h.{bid}.attn.attention.k_proj", # exaone
"model.layers.{bid}.self_attn.k_proj", # llama4
"model.transformer.blocks.{bid}.k_proj", # llada
"layers.{bid}.self_attn.k_proj", # qwen3-embedding
),
# Attention value
@ -221,6 +225,7 @@ class TensorNameMap:
"transformer.h.{bid}.attn.attention.v_proj", # exaone
"model.layers.{bid}.self_attn.v_proj", # llama4
"model.transformer.blocks.{bid}.v_proj", # llada
"layers.{bid}.self_attn.v_proj", # qwen3-embedding
),
# Attention output
@ -254,6 +259,7 @@ class TensorNameMap:
"model.layers.{bid}.self_attn.o_proj", # llama4
"transformer_encoder.{bid}.wo", # neobert
"model.transformer.blocks.{bid}.attn_out", # llada
"layers.{bid}.self_attn.o_proj", # qwen3-embedding
),
# Attention output norm
@ -300,6 +306,7 @@ class TensorNameMap:
"transformer_encoder.{bid}.ffn_norm", # neobert
"model.layers.layers.{bid}.pre_mlp_norm", # plamo2
"model.transformer.blocks.{bid}.ff_norm", # llada
"layers.{bid}.post_attention_layernorm", # qwen3-embedding
),
# Post feed-forward norm
@ -373,7 +380,8 @@ class TensorNameMap:
"model.layers.{bid}.feed_forward.up_proj", # llama4 jamba granite-hybrid
"transformer_encoder.{bid}.ffn.w12", # neobert
"model.layers.{bid}.block_sparse_moe.up", # smallthinker
"model.transformer.blocks.{bid}.up_proj", # llada
"model.transformer.blocks.{bid}.up_proj", # llada
"layers.{bid}.mlp.up_proj", # qwen3-embedding
),
MODEL_TENSOR.FFN_UP_EXP: (
@ -416,6 +424,7 @@ class TensorNameMap:
"model.layers.{bid}.feed_forward.gate_proj", # llama4 jamba granite-hybrid
"model.layers.{bid}.block_sparse_moe.gate", # smallthinker
"model.transformer.blocks.{bid}.ff_proj", # llada
"layers.{bid}.mlp.gate_proj", # qwen3-embedding
),
MODEL_TENSOR.FFN_GATE_EXP: (
@ -465,7 +474,8 @@ class TensorNameMap:
"model.layers.{bid}.feed_forward.down_proj", # llama4 jamba granite-hybrid
"transformer_encoder.{bid}.ffn.w3", # neobert
"model.layers.{bid}.block_sparse_moe.down", # smallthinker
"model.transformer.blocks.{bid}.ff_out", # llada
"model.transformer.blocks.{bid}.ff_out", # llada
"layers.{bid}.mlp.down_proj", # qwen3-embedding
),
MODEL_TENSOR.FFN_DOWN_EXP: (
@ -497,6 +507,7 @@ class TensorNameMap:
"encoder.layer.{bid}.attention.self.layer_norm_q", # jina-bert-v2
"transformer.layers.{bid}.attn.q_norm", # openelm
"model.layers.layers.{bid}.mixer.q", # plamo2
"layers.{bid}.self_attn.q_norm", # qwen3-embedding
),
MODEL_TENSOR.ATTN_K_NORM: (
@ -508,6 +519,7 @@ class TensorNameMap:
"encoder.layer.{bid}.attention.self.layer_norm_k", # jina-bert-v2
"transformer.layers.{bid}.attn.k_norm", # openelm
"model.layers.layers.{bid}.mixer.k", # plamo2
"layers.{bid}.self_attn.k_norm", # qwen3-embedding
),
MODEL_TENSOR.ROPE_FREQS: (

View file

@ -312,7 +312,11 @@ class SpecialVocab:
with open(config_file, encoding = 'utf-8') as f:
config = json.load(f)
for typ in self.special_token_types:
self._set_special_token(typ, config.get(f'{typ}_token_id'))
token_id = config.get(f'{typ}_token_id')
# If not found at root, check in text_config (for multimodal models like Kimi-VL)
if token_id is None and 'text_config' in config:
token_id = config['text_config'].get(f'{typ}_token_id')
self._set_special_token(typ, token_id)
return True

View file

@ -105,7 +105,7 @@ llama_context::llama_context(
{
const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS");
supports_set_rows = LLAMA_SET_ROWS ? (atoi(LLAMA_SET_ROWS) != 0) : false;
supports_set_rows = LLAMA_SET_ROWS ? (atoi(LLAMA_SET_ROWS) != 0) : supports_set_rows;
if (!supports_set_rows && !cparams.kv_unified) {
LLAMA_LOG_WARN("%s: non-unified KV cache requires ggml_set_rows() - forcing unified KV cache\n", __func__);

View file

@ -289,7 +289,7 @@ private:
// env: LLAMA_SET_ROWS (temporary)
// ref: https://github.com/ggml-org/llama.cpp/pull/14285
bool supports_set_rows = false;
bool supports_set_rows = true;
// env: LLAMA_GRAPH_REUSE_DISABLE
bool graph_reuse_disable = false;

View file

@ -183,7 +183,7 @@ llama_kv_cache_unified::llama_kv_cache_unified(
const size_t memory_size_k = size_k_bytes();
const size_t memory_size_v = size_v_bytes();
LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u/%2u seqs), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u/%u seqs), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
(float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), kv_size, (int) layers.size(), n_seq_max, n_stream,
ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
@ -193,7 +193,7 @@ llama_kv_cache_unified::llama_kv_cache_unified(
debug = LLAMA_KV_CACHE_DEBUG ? atoi(LLAMA_KV_CACHE_DEBUG) : 0;
const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS");
supports_set_rows = LLAMA_SET_ROWS ? atoi(LLAMA_SET_ROWS) != 0 : 0;
supports_set_rows = LLAMA_SET_ROWS ? atoi(LLAMA_SET_ROWS) != 0 : supports_set_rows;
if (!supports_set_rows) {
// ref: https://github.com/ggml-org/llama.cpp/pull/14363

View file

@ -230,7 +230,7 @@ private:
// env: LLAMA_SET_ROWS (temporary)
// ref: https://github.com/ggml-org/llama.cpp/pull/14285
bool supports_set_rows = false;
bool supports_set_rows = true;
const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;

View file

@ -25,6 +25,7 @@ llama_memory_hybrid::llama_memory_hybrid(
/* common */
uint32_t n_seq_max,
bool offload,
bool unified,
/* layer filters */
layer_filter_cb && filter_attn,
layer_filter_cb && filter_recr) :
@ -38,7 +39,7 @@ llama_memory_hybrid::llama_memory_hybrid(
type_v,
v_trans,
offload,
1,
unified,
kv_size,
n_seq_max,
n_pad,

View file

@ -39,6 +39,7 @@ public:
/* common */
uint32_t n_seq_max,
bool offload,
bool unified,
/* layer filters */
layer_filter_cb && filter_attn = nullptr,
layer_filter_cb && filter_recr = nullptr);

View file

@ -904,6 +904,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
} break;
case LLM_ARCH_QWEN3:
{
ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false);
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
switch (hparams.n_layer) {
case 28: type = hparams.n_embd == 1024 ? LLM_TYPE_0_6B : LLM_TYPE_1_7B; break;
@ -17697,6 +17698,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
/* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max),
/* n_seq_max */ cparams.n_seq_max,
/* offload */ cparams.offload_kqv,
/* unified */ cparams.kv_unified,
/* filter_attn */ (arch == LLM_ARCH_FALCON_H1) ? [&](int32_t) { return true; } : (llama_memory_hybrid::layer_filter_cb)nullptr,
/* filter_recr */ (arch == LLM_ARCH_FALCON_H1) ? [&](int32_t) { return true; } : (llama_memory_hybrid::layer_filter_cb)nullptr);
} else {

View file

@ -2092,7 +2092,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
tokenizer_pre == "gigachat" ||
tokenizer_pre == "jina-v2-es" ||
tokenizer_pre == "jina-v2-de" ||
tokenizer_pre == "a.x-4.0") {
tokenizer_pre == "a.x-4.0" ||
tokenizer_pre == "mellum") {
pre_type = LLAMA_VOCAB_PRE_TYPE_GPT2;
} else if (
tokenizer_pre == "jina-v1-en" ||

View file

@ -4249,9 +4249,6 @@ int main(int argc, char ** argv) {
// process prompt
std::vector<server_tokens> inputs;
if (oaicompat && !prompt.is_string()) {
throw std::runtime_error("prompt must be a string");
}
if (oaicompat && has_mtmd) {
// multimodal