Merge branch 'upstream' into concedo_experimental

# Conflicts:
#	.github/workflows/build.yml
#	ggml/cmake/ggml-config.cmake.in
#	ggml/src/ggml-cann/CMakeLists.txt
#	ggml/src/ggml-cann/common.h
#	ggml/src/ggml-cann/ggml-cann.cpp
#	ggml/src/ggml-cuda/fattn.cu
#	ggml/src/ggml-opencl/CMakeLists.txt
#	ggml/src/ggml-opencl/ggml-opencl.cpp
#	requirements/requirements-convert_hf_to_gguf.txt
#	scripts/compare-llama-bench.py
#	tests/test-chat-template.cpp
#	tests/test-chat.cpp
#	tools/llama-bench/llama-bench.cpp
Concedo 2025-08-07 21:23:09 +08:00
commit 8a71eb03c0
29 changed files with 1350 additions and 291 deletions


@ -55,7 +55,15 @@ bool common_chat_msg_parser::add_tool_call(const std::string & name, const std::
bool common_chat_msg_parser::add_tool_call(const json & tool_call) {
std::string name = tool_call.contains("name") ? tool_call.at("name") : "";
std::string id = tool_call.contains("id") ? tool_call.at("id") : "";
std::string arguments = tool_call.contains("arguments") ? tool_call.at("arguments") : "";
std::string arguments = "";
if (tool_call.contains("arguments")) {
if (tool_call.at("arguments").is_object()) {
arguments = tool_call.at("arguments").dump();
} else {
arguments = tool_call.at("arguments");
}
}
return add_tool_call(name, id, arguments);
}
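
For context, a minimal sketch of the two "arguments" shapes this overload now accepts (hypothetical payloads and tool name; the json alias is assumed to be the nlohmann-based one used elsewhere in common/):

#include <nlohmann/json.hpp>
using json = nlohmann::ordered_json; // assumption: matches the alias used by chat-parser.cpp

// 1) "arguments" already serialized as a string: forwarded to add_tool_call(name, id, arguments) unchanged
json call_str = { {"name", "get_weather"}, {"id", "call_1"}, {"arguments", "{\"city\": \"Boston\"}"} };

// 2) "arguments" given as a JSON object: now dump()'d to a string, instead of the old implicit
//    object-to-string conversion, which would throw a nlohmann type_error
json call_obj = { {"name", "get_weather"}, {"id", "call_2"}, {"arguments", { {"city", "Boston"} }} };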


@ -606,6 +606,7 @@ const char * common_chat_format_name(common_chat_format format) {
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: return "Functionary v3.1 Llama 3.1";
case COMMON_CHAT_FORMAT_HERMES_2_PRO: return "Hermes 2 Pro";
case COMMON_CHAT_FORMAT_COMMAND_R7B: return "Command R7B";
case COMMON_CHAT_FORMAT_GRANITE: return "Granite";
case COMMON_CHAT_FORMAT_GPT_OSS: return "GPT-OSS";
default:
throw std::runtime_error("Unknown chat format");
@ -618,6 +619,7 @@ const char * common_reasoning_format_name(common_reasoning_format format) {
case COMMON_REASONING_FORMAT_AUTO: return "auto";
case COMMON_REASONING_FORMAT_DEEPSEEK: return "deepseek";
case COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY: return "deepseek-legacy";
case COMMON_REASONING_FORMAT_GRANITE: return "granite";
default:
throw std::runtime_error("Unknown reasoning format");
}
@ -1734,6 +1736,124 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) {
builder.add_content(builder.consume_rest());
}
static common_chat_params common_chat_params_init_granite(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
// Pass thinking context for Granite template
json additional_context = {
{"thinking", inputs.enable_thinking},
};
data.prompt = apply(tmpl, inputs, /* messages_override= */ std::nullopt, /* tools_override= */ std::nullopt, additional_context);
data.format = COMMON_CHAT_FORMAT_GRANITE;
if (string_ends_with(data.prompt, "<think>\n") || string_ends_with(data.prompt, "<think>")) {
if (!inputs.enable_thinking) {
data.prompt += "</think>";
} else {
data.thinking_forced_open = true;
}
}
if (!inputs.tools.is_null()) {
// Granite uses <|tool_call|> followed by JSON list
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules;
foreach_function(inputs.tools, [&](const json & tool) {
const auto & function = tool.at("function");
std::string name = function.at("name");
auto parameters = function.at("parameters");
builder.resolve_refs(parameters);
tool_rules.push_back(builder.add_rule(name + "-call", builder.add_schema(name +
"-args", {
{"type", "object"},
{"properties", {
{"name", {{"const", name}}},
{"arguments", parameters},
}},
{"required", json::array({"name", "arguments"})},
})));
});
auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | "));
auto tool_list = builder.add_rule("tool_list", "\"[\" space " + tool_call + " (\",\" space " + tool_call + ")* space \"]\"");
if (data.thinking_forced_open) {
builder.add_rule("root", "\"</think>\" space \"<response>\" space [^<]* \"</response>\" space \"<|tool_call|>\" space " + tool_list);
} else {
builder.add_rule("root", "\"<|tool_call|>\" space " + tool_list);
}
data.grammar_triggers.push_back({
COMMON_GRAMMAR_TRIGGER_TYPE_WORD,
"<|tool_call|>"
});
data.preserved_tokens = {
"<think>",
"</think>",
"<response>",
"</response>",
"<|tool_call|>",
};
});
} else {
// Handle thinking tags for non-tool responses
if (data.thinking_forced_open && inputs.enable_thinking) {
data.grammar_lazy = false;
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
builder.add_rule("root", "\"</think>\" space \"<response>\" space .* \"</response>\" space");
});
data.preserved_tokens = {
"<think>",
"</think>",
"<response>",
"</response>",
};
}
}
return data;
}
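As a concrete illustration of the lazy grammar set up above: generation stays unconstrained until the literal <|tool_call|> trigger appears, after which the grammar only accepts a JSON list matching the registered tool schemas. A completion it is meant to accept (hypothetical tool name and arguments) would look like:

<|tool_call|>[{"name": "get_weather", "arguments": {"city": "Boston"}}]

With thinking_forced_open, the root rule additionally requires the closing </think> and a <response>...</response> block before that list.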
static void common_chat_parse_granite(common_chat_msg_parser & builder) {
// Parse thinking tags
builder.try_parse_reasoning("<think>", "</think>");
// Parse response tags using regex
static const common_regex response_regex("<response>([\\s\\S]*?)</response>");
if (auto res = builder.try_find_regex(response_regex)) {
// Extract the content between the tags (capture group 1)
auto content = builder.str(res->groups[1]);
builder.add_content(content);
builder.move_to(res->groups[0].end);
}
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
// Look for tool calls
static const common_regex tool_call_regex(regex_escape("<|tool_call|>"));
if (auto res = builder.try_find_regex(tool_call_regex)) {
builder.move_to(res->groups[0].end);
// Expect JSON array of tool calls
auto tool_calls_data = builder.consume_json();
if (tool_calls_data.json.is_array()) {
if (!builder.add_tool_calls(tool_calls_data.json)) {
builder.add_content("<|tool_call|>" + tool_calls_data.json.dump());
}
} else {
builder.add_content("<|tool_call|>" + tool_calls_data.json.dump());
}
} else {
builder.add_content(builder.consume_rest());
}
}
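End to end, a sketch of a raw Granite assistant message (hypothetical content) and how the parser above splits it:

<think>The user asked for the weather, so a tool call is needed.</think><response>Let me check that.</response><|tool_call|>[{"name": "get_weather", "arguments": {"city": "Boston"}}]

try_parse_reasoning lifts the <think> block into reasoning_content, the response regex captures "Let me check that." as content, and, when parse_tool_calls is enabled, consume_json reads the array after <|tool_call|> and hands it to add_tool_calls (falling back to plain content if the JSON is not an array or the calls are rejected).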
static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
data.prompt = apply(tmpl, inputs);
@ -1805,6 +1925,11 @@ static common_chat_params common_chat_templates_apply_jinja(
return common_chat_params_init_command_r7b(tmpl, params);
}
// Granite (IBM) - detects thinking / tools support
if (src.find("elif thinking") != std::string::npos && src.find("<|tool_call|>") != std::string::npos) {
return common_chat_params_init_granite(tmpl, params);
}
// Hermes 2/3 Pro, Qwen 2.5 Instruct (w/ tools)
if (src.find("<tool_call>") != std::string::npos && params.json_schema.is_null()) {
return common_chat_params_init_hermes_2_pro(tmpl, params);
@ -1865,6 +1990,7 @@ static common_chat_params common_chat_templates_apply_legacy(
int alloc_size = 0;
std::vector<llama_chat_message> chat;
std::vector<std::string> contents;
for (const auto & msg : inputs.messages) {
auto content = msg.content;
for (const auto & part : msg.content_parts) {
@ -1966,6 +2092,9 @@ static void common_chat_parse(common_chat_msg_parser & builder) {
case COMMON_CHAT_FORMAT_COMMAND_R7B:
common_chat_parse_command_r7b(builder);
break;
case COMMON_CHAT_FORMAT_GRANITE:
common_chat_parse_granite(builder);
break;
case COMMON_CHAT_FORMAT_GPT_OSS:
common_chat_parse_gpt_oss(builder);
break;


@ -109,6 +109,7 @@ enum common_chat_format {
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
COMMON_CHAT_FORMAT_HERMES_2_PRO,
COMMON_CHAT_FORMAT_COMMAND_R7B,
COMMON_CHAT_FORMAT_GRANITE,
COMMON_CHAT_FORMAT_GPT_OSS,
COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats


@ -235,6 +235,7 @@ enum common_reasoning_format {
COMMON_REASONING_FORMAT_AUTO,
COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY, // Extract thinking tag contents and return as `message.reasoning_content`, or leave inline in <think> tags in stream mode
COMMON_REASONING_FORMAT_DEEPSEEK, // Extract thinking tag contents and return as `message.reasoning_content`, including in streaming deltas.
COMMON_REASONING_FORMAT_GRANITE, // Extract thinking tag contents and return as `message.reasoning_content`, including in streaming deltas.
};
struct common_params {


@ -1077,6 +1077,11 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
}
}
}
// if the node is still unassigned, assign it to the first backend that supports it
for (int b = 0; b < sched->n_backends && *cur_backend_id == -1; b++) {
ggml_backend_sched_set_if_supported(sched, node, b, cur_backend_id);
}
GGML_ASSERT(*cur_backend_id != -1);
}
// pass 5: split graph, find tensors that need to be copied
@ -1104,7 +1109,7 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
const int node_backend_id = tensor_backend_id(node);
assert(node_backend_id != -1); // all nodes should be assigned by now, this can happen if there is no CPU fallback
GGML_ASSERT(node_backend_id != -1); // all nodes should be assigned by now, this can happen if there is no CPU fallback
// check if we should start a new split based on the sources of the current node
bool need_new_split = false;
@ -1162,7 +1167,7 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
size_t src_id = hash_id(src);
const int src_backend_id = sched->hv_tensor_backend_ids[src_id];
assert(src_backend_id != -1); // all inputs should be assigned by now
GGML_ASSERT(src_backend_id != -1); // all inputs should be assigned by now
if (src->flags & GGML_TENSOR_FLAG_INPUT && sched->n_copies > 1) {
if (tensor_id_copy(src_id, src_backend_id, 0) == NULL) {


@ -35,7 +35,7 @@
// ggml-backend interface
std::vector<ggml_backend_buffer_type_t>& ggml_backend_cpu_get_extra_buffers_type() {
std::vector<ggml_backend_buffer_type_t> & ggml_backend_cpu_get_extra_buffer_types() {
static std::vector<ggml_backend_buffer_type_t> bufts = []() {
std::vector<ggml_backend_buffer_type_t> bufts;
@ -57,8 +57,6 @@ std::vector<ggml_backend_buffer_type_t>& ggml_backend_cpu_get_extra_buffers_type
}
#endif
bufts.push_back(NULL);
return bufts;
}();
@ -66,14 +64,20 @@ std::vector<ggml_backend_buffer_type_t>& ggml_backend_cpu_get_extra_buffers_type
}
static ggml_backend_buffer_type_t * ggml_backend_cpu_device_get_extra_buffers_type(ggml_backend_dev_t device) {
return ggml_backend_cpu_get_extra_buffers_type().data();
static std::vector<ggml_backend_buffer_type_t> extra_bufts = [] {
std::vector<ggml_backend_buffer_type_t> bufts = ggml_backend_cpu_get_extra_buffer_types();
bufts.push_back(nullptr);
return bufts;
}();
return extra_bufts.data();
GGML_UNUSED(device);
}
static bool ggml_backend_cpu_is_extra_buffer_type(ggml_backend_buffer_type_t buft) {
for (auto * extra : ggml_backend_cpu_get_extra_buffers_type()) {
if (extra && extra == buft) {
for (auto * extra : ggml_backend_cpu_get_extra_buffer_types()) {
if (extra == buft) {
return true;
}
}
@ -397,20 +401,13 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st
return true;
}
// extra_buffer_op?
for (auto extra : ggml_backend_cpu_get_extra_buffers_type()) {
if (extra) {
auto buf_extra = (ggml::cpu::extra_buffer_type*) extra->context;
if (buf_extra && buf_extra->supports_op(dev, op)) {
return true;
}
}
}
// the other case need host buffer.
for (int i = 0; i < GGML_MAX_SRC; i++) {
if (op->src[i] && op->src[i]->buffer && !ggml_backend_buft_is_host(op->src[i]->buffer->buft)) {
return false;
// check extra buffer types
// note: only the first sources are checked for extra buffer types to reduce overhead, increase if necessary
for (int i = 0; i < 4; i++) {
if (op->src[i] && op->src[i]->buffer &&
ggml_backend_cpu_is_extra_buffer_type(op->src[i]->buffer->buft)) {
auto * buf_extra = (ggml::cpu::extra_buffer_type *) op->src[i]->buffer->buft->context;
return buf_extra->supports_op(dev, op);
}
}


@ -10,7 +10,7 @@ extra_buffer_type::~extra_buffer_type() {}
} // namespace ggml::cpu
bool ggml_cpu_extra_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) {
for (auto extra : ggml_backend_cpu_get_extra_buffers_type()) {
for (auto extra : ggml_backend_cpu_get_extra_buffer_types()) {
if (extra && extra->context) {
auto buf_extra = (ggml::cpu::extra_buffer_type *) extra->context;
auto tensor_traits = buf_extra->get_tensor_traits(op);
@ -23,7 +23,7 @@ bool ggml_cpu_extra_compute_forward(struct ggml_compute_params * params, struct
}
bool ggml_cpu_extra_work_size(int n_threads, const struct ggml_tensor * op, size_t * size) {
for (auto extra : ggml_backend_cpu_get_extra_buffers_type()) {
for (auto extra : ggml_backend_cpu_get_extra_buffer_types()) {
if (extra && extra->context) {
auto buf_extra = (ggml::cpu::extra_buffer_type *) extra->context;
auto tensor_traits = buf_extra->get_tensor_traits(op);


@ -33,6 +33,6 @@ class extra_buffer_type {
} // namespace ggml::cpu
// implemented in ggml-cpu.cpp.
std::vector<ggml_backend_buffer_type_t> & ggml_backend_cpu_get_extra_buffers_type();
std::vector<ggml_backend_buffer_type_t> & ggml_backend_cpu_get_extra_buffer_types();
#endif


@ -237,9 +237,13 @@ typedef float2 dfloat2;
#endif // defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA)
#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
#define NEW_MMA_AVAILABLE
#define TURING_MMA_AVAILABLE
#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#define AMPERE_MMA_AVAILABLE
#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#define CP_ASYNC_AVAILABLE
#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
@ -307,10 +311,14 @@ static bool amd_mfma_available(const int cc) {
}
// Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
static bool new_mma_available(const int cc) {
static bool turing_mma_available(const int cc) {
return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING;
}
static bool ampere_mma_available(const int cc) {
return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE;
}
static bool cp_async_available(const int cc) {
return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE;
}
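
A short note on how the renamed gates are expected to resolve (a sketch; the boundaries follow from the GGML_CUDA_CC_TURING/AMPERE comparisons above and assume the corresponding arch is actually compiled in):

// cc 7.0 (Volta):           turing_mma_available -> false, ampere_mma_available -> false (WMMA fallback)
// cc 7.5 (Turing):          turing_mma_available -> true,  ampere_mma_available -> false
// cc 8.0+ (Ampere, NVIDIA): turing_mma_available -> true,  ampere_mma_available -> true, cp_async_available -> true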


@ -418,7 +418,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
float * const __restrict__ KQ_max,
float * const __restrict__ KQ_rowsum,
const int kb0) {
#ifdef NEW_MMA_AVAILABLE
#ifdef TURING_MMA_AVAILABLE
typedef fattn_mma_f16_config<DKQ, DV> c;
#ifdef CP_ASYNC_AVAILABLE
@ -776,7 +776,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
GGML_UNUSED(VKQ_C); GGML_UNUSED(KQ_max); GGML_UNUSED(KQ_rowsum);
GGML_UNUSED(kb0); GGML_UNUSED(tile_Q);
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
#endif // TURING_MMA_AVAILABLE
}
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup>
@ -800,7 +800,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
const int jt,
const int kb0_start,
const int kb0_stop) {
#ifdef NEW_MMA_AVAILABLE
#ifdef TURING_MMA_AVAILABLE
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
typedef fattn_mma_f16_config<DKQ, DV> c;
@ -1196,7 +1196,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
GGML_UNUSED(stride_Q2); GGML_UNUSED(stride_K); GGML_UNUSED(stride_V); GGML_UNUSED(stride_mask);
GGML_UNUSED(jt); GGML_UNUSED(kb0_start); GGML_UNUSED(kb0_stop);
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
#endif // TURING_MMA_AVAILABLE
}
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla>
@ -1223,7 +1223,7 @@ static __global__ void flash_attn_ext_f16(
const int32_t nb21, const int32_t nb22, const int64_t nb23,
const int32_t ne31, const int32_t ne32, const int32_t ne33,
const int32_t nb31, const int32_t nb32, const int64_t nb33) {
#if defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
#if defined(FLASH_ATTN_AVAILABLE) && defined(TURING_MMA_AVAILABLE)
// Skip unused kernel variants for faster compilation:
if (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) {
@ -1354,7 +1354,7 @@ static __global__ void flash_attn_ext_f16(
GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33);
NO_DEVICE_CODE;
#endif // defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
#endif // defined(FLASH_ATTN_AVAILABLE) && defined(TURING_MMA_AVAILABLE)
}
template <int DKQ, int DV, int ncols1, int ncols2>


@ -327,7 +327,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
const bool gqa_opt_applies = ((Q->ne[2] / K->ne[2]) % 2 == 0) && mask; // The mma-based kernels have GQA-specific optimizations
const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16;
const bool mma_faster_for_rtx4000 = Q->ne[3] > 1 || (Q->ne[2] > 4*K->ne[2] && K->ne[1] >= 8192);
const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && !mma_needs_data_conversion &&
const bool mma_faster_for_bs1 = turing_mma_available(cc) && gqa_opt_applies && !mma_needs_data_conversion &&
(cc < GGML_CUDA_CC_ADA_LOVELACE || mma_faster_for_rtx4000);
const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % (2*warp_size) == 0;
if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) {
@ -340,7 +340,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
}
// The MMA implementation needs Turing or newer, use the old WMMA code for Volta:
if (ggml_cuda_highest_compiled_arch(cc) <= GGML_CUDA_CC_TURING || cc == GGML_CUDA_CC_TURING || (fp16_mma_available(cc) && !new_mma_available(cc))) { //kcpp: use wmma to fix cu11 incoherence
if (ggml_cuda_highest_compiled_arch(cc) <= GGML_CUDA_CC_TURING || cc == GGML_CUDA_CC_TURING || (fp16_mma_available(cc) && !turing_mma_available(cc))) { //kcpp: use wmma to fix cu11 incoherence
ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
return;
}


@ -24,8 +24,9 @@ bool g_mul_mat_q = true;
#include "ggml-cuda/fattn.cuh"
#include "ggml-cuda/getrows.cuh"
#include "ggml-cuda/im2col.cuh"
#include "ggml-cuda/mmf.cuh"
#include "ggml-cuda/mmq.cuh"
#include "ggml-cuda/mmv.cuh"
#include "ggml-cuda/mmvf.cuh"
#include "ggml-cuda/mmvq.cuh"
#include "ggml-cuda/norm.cuh"
#include "ggml-cuda/opt-step-adamw.cuh"
@ -2009,7 +2010,9 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
const bool bad_padding_clear = ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE
&& ggml_nbytes(src0) != ggml_backend_buffer_get_alloc_size(src0->buffer, src0) && src0->view_src;
bool use_mul_mat_vec = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16)
bool use_mul_mat_vec_f = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16)
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
bool use_mul_mat_f = !ggml_is_quantized(src0->type)
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && !bad_padding_clear
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
@ -2029,14 +2032,18 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
}
const int cc = ggml_cuda_info().devices[id].cc;
const int warp_size = ggml_cuda_info().devices[id].warp_size;
use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
use_mul_mat_vec = use_mul_mat_vec && ggml_cuda_should_use_mmv(src0->type, cc, src0->ne, src1->ne[1]);
use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src1->ne[1]);
use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src1->ne[1]);
any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
}
} else {
const int cc = ggml_cuda_info().devices[ctx.device].cc;
const int warp_size = ggml_cuda_info().devices[ctx.device].warp_size;
use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
use_mul_mat_vec = use_mul_mat_vec && ggml_cuda_should_use_mmv(src0->type, cc, src0->ne, src1->ne[1]);
use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src1->ne[1]);
use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src1->ne[1]);
any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
}
@ -2054,10 +2061,12 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
bool use_batched_cublas_bf16 = src0->type == GGML_TYPE_BF16 && bf16_mma_hardware_available(cc);
bool use_batched_cublas_f32 = src0->type == GGML_TYPE_F32;
if (!split && use_mul_mat_vec) {
if (!split && use_mul_mat_vec_f) {
// the custom F16 vector kernel can be used over batched cuBLAS GEMM
// but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention)
ggml_cuda_mul_mat_vec(ctx, src0, src1, nullptr, dst);
ggml_cuda_mul_mat_vec_f(ctx, src0, src1, nullptr, dst);
} else if (!split && use_mul_mat_f) {
ggml_cuda_mul_mat_f(ctx, src0, src1, nullptr, dst);
} else if (!split && use_mul_mat_vec_q) {
ggml_cuda_mul_mat_vec_q(ctx, src0, src1, nullptr, dst);
} else if (!split && use_mul_mat_q) {
@ -2066,8 +2075,8 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
&& !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
// general KQ + KQV multi-batch without FlashAttention
ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
} else if (use_mul_mat_vec) {
ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec, nullptr);
} else if (use_mul_mat_vec_f) {
ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_f, nullptr);
} else if (use_mul_mat_vec_q) {
ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, quantize_row_q8_1_cuda);
} else if (use_mul_mat_q) {
@ -2095,7 +2104,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
if (ggml_is_quantized(src0->type)) {
ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst);
} else {
ggml_cuda_mul_mat_vec(ctx, src0, src1, ids, dst);
ggml_cuda_mul_mat_vec_f(ctx, src0, src1, ids, dst);
}
return;
}
@ -3521,7 +3530,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
#endif // FLASH_ATTN_AVAILABLE
if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
if (!new_mma_available(cc)) {
if (!turing_mma_available(cc)) {
return false;
}
const int gqa_ratio = op->src[0]->ne[2] / op->src[1]->ne[2];


@ -23,13 +23,13 @@
static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) {
int ret = 0;
#ifdef NEW_MMA_AVAILABLE
#ifdef TURING_MMA_AVAILABLE
asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
: "=r"(ret) : "r"(x));
#else
GGML_UNUSED(x);
NO_DEVICE_CODE;
#endif // defined(NEW_MMA_AVAILABLE)
#endif // defined(TURING_MMA_AVAILABLE)
return ret;
}
@ -167,6 +167,38 @@ namespace ggml_cuda_mma {
}
};
template <int I_, int J_>
struct tile<I_, J_, nv_bfloat162> {
static constexpr int I = I_;
static constexpr int J = J_;
static constexpr int ne = I * J / WARP_SIZE;
nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
static __device__ __forceinline__ int get_i(const int l) {
if constexpr (I == 8 && J == 8) {
return threadIdx.x / 4;
} else if constexpr (I == 16 && J == 4) {
return l * 8 + threadIdx.x / 4;
} else if constexpr (I == 16 && J == 8) {
return (l % 2) * 8 + threadIdx.x / 4;
} else {
static_assert(I == -1 && J == -1, "template specialization not implemented");
}
}
static __device__ __forceinline__ int get_j(const int l) {
if constexpr (I == 8 && J == 8) {
return l * 4 + threadIdx.x % 4;
} else if constexpr (I == 16 && J == 4) {
return threadIdx.x % 4;
} else if constexpr (I == 16 && J == 8) {
return (l / 2) * 4 + threadIdx.x % 4;
} else {
static_assert(I == -1 && J == -1, "template specialization not implemented");
}
}
};
template <int I, int J>
static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
tile<I, J/2, half2> ret;
@ -209,7 +241,7 @@ namespace ggml_cuda_mma {
template <typename T>
static __device__ __forceinline__ void load_ldmatrix(
tile<8, 8, T> & t, const T * __restrict__ xs0, const int stride) {
#ifdef NEW_MMA_AVAILABLE
#ifdef TURING_MMA_AVAILABLE
int * xi = (int *) t.x;
const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + ((threadIdx.x / t.I) * (t.J / 2)) % t.J;
asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
@ -217,13 +249,13 @@ namespace ggml_cuda_mma {
: "l"(xs));
#else
load_generic(t, xs0, stride);
#endif // NEW_MMA_AVAILABLE
#endif // TURING_MMA_AVAILABLE
}
template <typename T>
static __device__ __forceinline__ void load_ldmatrix(
tile<16, 4, T> & t, const T * __restrict__ xs0, const int stride) {
#ifdef NEW_MMA_AVAILABLE
#ifdef TURING_MMA_AVAILABLE
int * xi = (int *) t.x;
const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride;
asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
@ -232,13 +264,13 @@ namespace ggml_cuda_mma {
#else
load_generic(xs0, stride);
GGML_UNUSED(t);
#endif // NEW_MMA_AVAILABLE
#endif // TURING_MMA_AVAILABLE
}
template <typename T>
static __device__ __forceinline__ void load_ldmatrix(
tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
#if defined(NEW_MMA_AVAILABLE)
#if defined(TURING_MMA_AVAILABLE)
int * xi = (int * ) t.x;
const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
asm volatile("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
@ -246,13 +278,13 @@ namespace ggml_cuda_mma {
: "l"(xs));
#else
load_generic(t, xs0, stride);
#endif // NEW_MMA_AVAILABLE
#endif // TURING_MMA_AVAILABLE
}
template <typename T>
static __device__ __forceinline__ void load_ldmatrix_trans(
tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
#ifdef NEW_MMA_AVAILABLE
#ifdef TURING_MMA_AVAILABLE
int * xi = (int * ) t.x;
const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.b16 {%0, %1, %2, %3}, [%4];"
@ -263,12 +295,12 @@ namespace ggml_cuda_mma {
GGML_UNUSED(xs0);
GGML_UNUSED(stride);
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
#endif // TURING_MMA_AVAILABLE
}
static __device__ __forceinline__ void mma(
tile<16, 8, int> & D, const tile<16, 4, int> & A, const tile<8, 4, int> & B) {
#ifdef NEW_MMA_AVAILABLE
#ifdef TURING_MMA_AVAILABLE
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
: "+r"(D.x[0]), "+r"(D.x[1]), "+r"(D.x[2]), "+r"(D.x[3])
@ -287,12 +319,12 @@ namespace ggml_cuda_mma {
GGML_UNUSED(A);
GGML_UNUSED(B);
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
#endif // TURING_MMA_AVAILABLE
}
static __device__ __forceinline__ void mma(
tile<16, 8, int> & D, const tile<16, 8, int> & A, const tile<8, 8, int> & B) {
#ifdef NEW_MMA_AVAILABLE
#ifdef TURING_MMA_AVAILABLE
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
asm("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
: "+r"(D.x[0]), "+r"(D.x[1]), "+r"(D.x[2]), "+r"(D.x[3])
@ -317,12 +349,12 @@ namespace ggml_cuda_mma {
GGML_UNUSED(A);
GGML_UNUSED(B);
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
#endif // TURING_MMA_AVAILABLE
}
static __device__ __forceinline__ void mma(
tile<16, 4, half2> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) {
#ifdef NEW_MMA_AVAILABLE
#ifdef TURING_MMA_AVAILABLE
const int * Axi = (const int *) A.x;
const int * Bxi = (const int *) B.x;
int * Dxi = (int *) D.x;
@ -344,12 +376,12 @@ namespace ggml_cuda_mma {
GGML_UNUSED(A);
GGML_UNUSED(B);
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
#endif // TURING_MMA_AVAILABLE
}
static __device__ __forceinline__ void mma(
tile<16, 8, half2> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) {
#ifdef NEW_MMA_AVAILABLE
#ifdef TURING_MMA_AVAILABLE
const int * Axi = (const int *) A.x;
const int * Bxi = (const int *) B.x;
int * Dxi = (int *) D.x;
@ -380,12 +412,29 @@ namespace ggml_cuda_mma {
GGML_UNUSED(A);
GGML_UNUSED(B);
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
#endif // TURING_MMA_AVAILABLE
}
static __device__ __forceinline__ void mma(
tile<16, 8, float> & D, const tile<16, 8, float> & A, const tile<8, 8, float> & B) {
#ifdef AMPERE_MMA_AVAILABLE
const int * Axi = (const int *) A.x;
const int * Bxi = (const int *) B.x;
int * Dxi = (int *) D.x;
asm("mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
: "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]));
#else
GGML_UNUSED(D);
GGML_UNUSED(A);
GGML_UNUSED(B);
NO_DEVICE_CODE;
#endif // AMPERE_MMA_AVAILABLE
}
static __device__ __forceinline__ void mma(
tile<16, 8, float> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) {
#ifdef NEW_MMA_AVAILABLE
#ifdef TURING_MMA_AVAILABLE
const int * Axi = (const int *) A.x;
const int * Bxi = (const int *) B.x;
int * Dxi = (int *) D.x;
@ -407,12 +456,29 @@ namespace ggml_cuda_mma {
GGML_UNUSED(A);
GGML_UNUSED(B);
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
#endif // TURING_MMA_AVAILABLE
}
static __device__ __forceinline__ void mma(
tile<16, 8, float> & D, const tile<16, 8, nv_bfloat162> & A, const tile<8, 8, nv_bfloat162> & B) {
#ifdef AMPERE_MMA_AVAILABLE
const int * Axi = (const int *) A.x;
const int * Bxi = (const int *) B.x;
int * Dxi = (int *) D.x;
asm("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
: "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]));
#else
GGML_UNUSED(D);
GGML_UNUSED(A);
GGML_UNUSED(B);
NO_DEVICE_CODE;
#endif // AMPERE_MMA_AVAILABLE
}
static __device__ __forceinline__ void mma(
tile<16, 16, float> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) {
#ifdef NEW_MMA_AVAILABLE
#ifdef TURING_MMA_AVAILABLE
const int * Axi = (const int *) A.x;
const int * Bxi = (const int *) B.x;
int * Dxi = (int *) D.x;
@ -443,7 +509,7 @@ namespace ggml_cuda_mma {
GGML_UNUSED(A);
GGML_UNUSED(B);
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
#endif // TURING_MMA_AVAILABLE
}
static __device__ __forceinline__ void mma(

ggml/src/ggml-cuda/mmf.cu (new file, 431 lines added)

@ -0,0 +1,431 @@
#include "ggml.h"
#include "common.cuh"
#include "mma.cuh"
#include "mmf.cuh"
using namespace ggml_cuda_mma;
#define MMF_ROWS_PER_BLOCK 32
template <typename T, int rows_per_block, int cols_per_block, int nwarps>
__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1)
static __global__ void mul_mat_f(
const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
const int ncols, const int nchannels_y, const int stride_row, const int stride_col_y, const int stride_col_dst,
const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
typedef tile<16, 8, T> tile_A;
typedef tile< 8, 8, T> tile_B;
typedef tile<16, 8, float> tile_C;
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
constexpr int tile_k_padded = warp_size + 4;
constexpr int ntA = rows_per_block / tile_A::I;
constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I;
const int row0 = blockIdx.x * rows_per_block;
const int channel_dst = blockIdx.y;
const int channel_x = channel_dst / channel_ratio;
const int channel_y = channel_dst;
const int sample_dst = blockIdx.z;
const int sample_x = sample_dst / sample_ratio;
const int sample_y = sample_dst;
x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row0*stride_row ;
y += int64_t(sample_y) *stride_sample_y + channel_y *stride_channel_y;
dst += int64_t(sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst;
const float2 * y2 = (const float2 *) y;
extern __shared__ char data_mmv[];
tile_C C[ntA][ntB];
T * tile_xy = (T *) data_mmv + threadIdx.y*(tile_A::I * tile_k_padded);
for (int col = threadIdx.y*warp_size + threadIdx.x; col < ncols; col += nwarps*warp_size) {
tile_A A[ntA][warp_size / tile_A::J];
#pragma unroll
for (int itA = 0; itA < ntA; ++itA) {
#pragma unroll
for (int i = 0; i < tile_A::I; ++i) {
tile_xy[i*tile_k_padded + threadIdx.x] = x[(itA*tile_A::I + i)*stride_row + col];
}
#pragma unroll
for (int k0 = 0; k0 < warp_size; k0 += tile_A::J) {
load_ldmatrix(A[itA][k0/tile_A::J], tile_xy + k0, tile_k_padded);
}
}
#pragma unroll
for (int itB = 0; itB < ntB; ++itB) {
if constexpr (std::is_same_v<T, float>) {
#pragma unroll
for (int j0 = 0; j0 < tile_B::I; ++j0) {
const int j = j0 + itB*tile_B::I;
tile_xy[j0*tile_k_padded + threadIdx.x] = j < cols_per_block ? y[j*stride_col_y + col] : 0.0f;
}
} else if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
#pragma unroll
for (int j0 = 0; j0 < tile_B::I; ++j0) {
const int j = j0 + itB*tile_B::I;
const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2(0.0f, 0.0f);
tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
}
} else {
static_assert(std::is_same_v<T, void>, "unsupported type");
}
#pragma unroll
for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) {
tile_B B;
load_ldmatrix(B, tile_xy + k0, tile_k_padded);
#pragma unroll
for (int itA = 0; itA < ntA; ++itA) {
mma(C[itA][itB], A[itA][k0/tile_B::J], B);
}
}
}
}
float * buf_iw = (float *) data_mmv;
constexpr int kiw = nwarps*rows_per_block + 4;
if (nwarps > 1) {
__syncthreads();
}
#pragma unroll
for (int itB = 0; itB < ntB; ++itB) {
#pragma unroll
for (int itA = 0; itA < ntA; ++itA) {
#pragma unroll
for (int l = 0; l < tile_C::ne; ++l) {
const int i = threadIdx.y*rows_per_block + itA*tile_C::I + tile_C::get_i(l);
const int j = itB*tile_C::J + tile_C::get_j(l);
buf_iw[j*kiw + i] = C[itA][itB].x[l];
}
}
}
if (nwarps > 1) {
__syncthreads();
}
#pragma unroll
for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) {
const int j = j0 + threadIdx.y;
if (j0 + nwarps > cols_per_block && j >= cols_per_block) {
return;
}
float sum = 0.0f;
static_assert(rows_per_block == warp_size, "need loop/check");
#pragma unroll
for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) {
const int i = i0 + threadIdx.x;
sum += buf_iw[j*kiw + i];
}
dst[j*stride_col_dst + row0 + threadIdx.x] = sum;
}
#else
NO_DEVICE_CODE;
GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(ids); GGML_UNUSED(dst);
GGML_UNUSED(ncols); GGML_UNUSED(nchannels_y); GGML_UNUSED(stride_row); GGML_UNUSED(stride_col_y); GGML_UNUSED(stride_col_dst);
GGML_UNUSED(channel_ratio); GGML_UNUSED(stride_channel_x); GGML_UNUSED(stride_channel_y); GGML_UNUSED(stride_channel_dst);
GGML_UNUSED(sample_ratio); GGML_UNUSED(stride_sample_x); GGML_UNUSED(stride_sample_y); GGML_UNUSED(stride_sample_dst);
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
}
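A quick worked example of the register tiling in this kernel (the numbers follow directly from the constants above): with rows_per_block = MMF_ROWS_PER_BLOCK = 32 and tile_A::I = 16, each warp covers ntA = 32/16 = 2 A-tiles per K-step; with cols_per_block at most 16 and tile_B::I = 8, ntB = (16 + 7)/8 = 2, so the per-warp accumulator C is at most a 2x2 grid of 16x8 tile_C tiles.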
template <typename T, int cols_per_block>
static void mul_mat_f_cuda(
const T * x, const float * y, const int32_t * ids, float * dst,
const int64_t ncols_x, const int64_t nrows_x,
const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
cudaStream_t stream) {
typedef tile<16, 8, T> tile_A;
typedef tile< 8, 8, T> tile_B;
typedef tile<16, 8, float> tile_C;
GGML_ASSERT(!ids && "mul_mat_id not implemented");
GGML_ASSERT(ncols_x % 2 == 0);
GGML_ASSERT(stride_row % 2 == 0);
GGML_ASSERT(stride_col_y % 2 == 0);
GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
GGML_ASSERT( nsamples_dst % nsamples_x == 0);
const int64_t channel_ratio = nchannels_dst / nchannels_x;
const int64_t sample_ratio = nsamples_dst / nsamples_x;
const int device = ggml_cuda_get_device();
const int warp_size = ggml_cuda_info().devices[device].warp_size;
int64_t nwarps_best = 1;
int64_t niter_best = (ncols_x + warp_size*2 - 1) / (warp_size*2);
int64_t max_block_size = 256;
for (int64_t nwarps = 2; nwarps <= max_block_size/warp_size; nwarps++) {
const int64_t niter = (ncols_x + nwarps*warp_size*2 - 1) / (nwarps*warp_size*2);
if (niter < niter_best) {
niter_best = niter;
nwarps_best = nwarps;
}
}
constexpr int rows_per_block = MMF_ROWS_PER_BLOCK;
const int nbytes_shared_iter = nwarps_best * tile_A::I * (warp_size + 4) * 4;
const int nbytes_shared_combine = GGML_PAD(cols_per_block, tile_B::I) * (nwarps_best*rows_per_block + 4) * 4;
const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine);
const dim3 block_nums(nrows_x/rows_per_block, nchannels_dst, nsamples_dst);
const dim3 block_dims(warp_size, nwarps_best, 1);
switch (nwarps_best) {
case 1: {
mul_mat_f<T, rows_per_block, cols_per_block, 1><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 2: {
mul_mat_f<T, rows_per_block, cols_per_block, 2><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 3: {
mul_mat_f<T, rows_per_block, cols_per_block, 3><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 4: {
mul_mat_f<T, rows_per_block, cols_per_block, 4><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 5: {
mul_mat_f<T, rows_per_block, cols_per_block, 5><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 6: {
mul_mat_f<T, rows_per_block, cols_per_block, 6><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 7: {
mul_mat_f<T, rows_per_block, cols_per_block, 7><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 8: {
mul_mat_f<T, rows_per_block, cols_per_block, 8><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
default: {
GGML_ABORT("fatal error");
} break;
}
}
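For intuition, a worked example of the nwarps search above, assuming warp_size = 32, max_block_size = 256 and ncols_x = 4096: nwarps = 1 gives niter = ceil(4096/64) = 64, nwarps = 8 gives niter = ceil(4096/512) = 8, so the loop settles on nwarps_best = 8 and the kernel is launched with 32x8 thread blocks.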
template <typename T>
static void mul_mat_f_switch_cols_per_block(
const T * x, const float * y, const int32_t * ids, float * dst,
const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst,
const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
cudaStream_t stream) {
switch (ncols_dst) {
case 1: {
mul_mat_f_cuda<T, 1>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 2: {
mul_mat_f_cuda<T, 2>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 3: {
mul_mat_f_cuda<T, 3>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 4: {
mul_mat_f_cuda<T, 4>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 5: {
mul_mat_f_cuda<T, 5>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 6: {
mul_mat_f_cuda<T, 6>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 7: {
mul_mat_f_cuda<T, 7>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 8: {
mul_mat_f_cuda<T, 8>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 9: {
mul_mat_f_cuda<T, 9>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 10: {
mul_mat_f_cuda<T, 10>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 11: {
mul_mat_f_cuda<T, 11>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 12: {
mul_mat_f_cuda<T, 12>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 13: {
mul_mat_f_cuda<T, 13>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 14: {
mul_mat_f_cuda<T, 14>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 15: {
mul_mat_f_cuda<T, 15>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 16: {
mul_mat_f_cuda<T, 16>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
default: {
GGML_ABORT("fatal error");
} break;
}
}
void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
GGML_ASSERT( src1->type == GGML_TYPE_F32);
GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_TENSOR_BINARY_OP_LOCALS;
const size_t ts_src0 = ggml_type_size(src0->type);
const size_t ts_src1 = ggml_type_size(src1->type);
const size_t ts_dst = ggml_type_size(dst->type);
GGML_ASSERT(ne13 == ne3);
GGML_ASSERT( nb00 == ts_src0);
GGML_ASSERT( nb10 == ts_src1);
GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type));
GGML_ASSERT( nb0 == ts_dst);
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
const float * src1_d = (const float *) src1->data;
const int32_t * ids_d = ids ? (const int32_t *) ids->data : nullptr;
float * dst_d = (float *) dst->data;
const int64_t s01 = src0->nb[1] / ts_src0;
const int64_t s11 = src1->nb[1] / ts_src1;
const int64_t s1 = dst->nb[1] / ts_dst;
const int64_t s02 = src0->nb[2] / ts_src0;
const int64_t s12 = src1->nb[2] / ts_src1;
const int64_t s2 = dst->nb[2] / ts_dst;
const int64_t s03 = src0->nb[3] / ts_src0;
const int64_t s13 = src1->nb[3] / ts_src1;
const int64_t s3 = dst->nb[3] / ts_dst;
// For MUL_MAT_ID the memory layout is different than for MUL_MAT:
const int64_t ncols_dst = ids ? ne2 : ne1;
const int64_t nchannels_y = ids ? ne11 : ne12;
const int64_t nchannels_dst = ids ? ne1 : ne2;
const int64_t stride_channel_dst = ids ? s1 : s2;
const int64_t stride_channel_y = ids ? s11 : s12;
GGML_ASSERT(!ids || ncols_dst == 1);
switch (src0->type) {
case GGML_TYPE_F32: {
const float * src0_d = (const float *) src0->data;
constexpr int vals_per_T = 1;
mul_mat_f_switch_cols_per_block(
src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, s11/vals_per_T, s1,
ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream());
} break;
case GGML_TYPE_F16: {
const half2 * src0_d = (const half2 *) src0->data;
constexpr int vals_per_T = 2;
mul_mat_f_switch_cols_per_block(
src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, s11/vals_per_T, s1,
ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream());
} break;
case GGML_TYPE_BF16: {
const nv_bfloat162 * src0_d = (const nv_bfloat162 *) src0->data;
constexpr int vals_per_T = 2;
mul_mat_f_switch_cols_per_block(
src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, s11/vals_per_T, s1,
ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream());
} break;
default:
GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
}
}
bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne, int64_t ne11) {
if (src0_ne[0] % (warp_size * (4/ggml_type_size(type))) != 0) {
return false;
}
if (src0_ne[1] % MMF_ROWS_PER_BLOCK != 0) {
return false;
}
if (ne11 > 16) {
return false;
}
switch (type) {
case GGML_TYPE_F32:
return ampere_mma_available(cc);
case GGML_TYPE_F16:
return turing_mma_available(cc);
case GGML_TYPE_BF16:
return ampere_mma_available(cc);
default:
return false;
}
}
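Putting the checks above together, a worked example of when the new path is eligible (assumed shapes, not taken from the commit): with warp_size = 32, an F16 src0 (type size 2) needs src0_ne[0] to be a multiple of 32 * (4/2) = 64 and src0_ne[1] a multiple of MMF_ROWS_PER_BLOCK = 32, the batch size ne11 must not exceed 16, and the device must pass turing_mma_available; F32 relaxes the column constraint to a multiple of 32 but requires ampere_mma_available (TF32 MMA), while BF16 keeps the multiple-of-64 constraint and also requires ampere_mma_available.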


@ -0,0 +1,5 @@
#include "common.cuh"
void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne, int64_t ne11);


@ -312,7 +312,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
return false;
}
if (new_mma_available(cc)) {
if (turing_mma_available(cc)) {
return true;
}


@ -93,7 +93,7 @@ struct tile_x_sizes {
};
static int get_mmq_x_max_host(const int cc) {
return (amd_mfma_available(cc) || new_mma_available(cc)) ? 128 :
return (amd_mfma_available(cc) || turing_mma_available(cc)) ? 128 :
GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ?
#ifdef GGML_CUDA_FORCE_MMQ
128 : 64;
@ -103,9 +103,9 @@ static int get_mmq_x_max_host(const int cc) {
}
static constexpr __device__ int get_mmq_x_max_device() {
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
return 128;
#else // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#else // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
#if defined(GGML_USE_HIP)
return 64;
@ -122,7 +122,7 @@ static constexpr __device__ int get_mmq_x_max_device() {
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
#endif // defined(GGML_USE_HIP)
#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
}
static int get_mmq_y_host(const int cc) {
@ -234,7 +234,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
static int mmq_get_granularity_host(const int mmq_x, const int cc) {
if (amd_mfma_available(cc)) {
return mmq_x >= 128 ? 32 : 16;
} else if (new_mma_available(cc) && mmq_x >= 48) {
} else if (turing_mma_available(cc) && mmq_x >= 48) {
return 16;
} else {
return 8;
@ -245,7 +245,7 @@ static int mmq_get_granularity_host(const int mmq_x, const int cc) {
static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) {
return mmq_x >= 128 ? 32 : 16;
}
#elif defined(NEW_MMA_AVAILABLE)
#elif defined(TURING_MMA_AVAILABLE)
static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) {
return mmq_x >= 48 ? 16 : 8;
}
@ -280,14 +280,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
constexpr int nwarps = mmq_get_nwarps_device();
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + 2*MMQ_TILE_NE_K);
#else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + txs.qs);
#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_0);
constexpr int nrows = warp_size / threads_per_row;
@ -306,12 +306,12 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx;
const int qs0 = get_int_b2(bxi->qs, kqsx);
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + 0] = __vsubss4((qs0 >> 0) & 0x0F0F0F0F, 0x08080808);
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + QI4_0] = __vsubss4((qs0 >> 4) & 0x0F0F0F0F, 0x08080808);
#else
x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0;
#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
}
constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_0;
@ -328,11 +328,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbxd;
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
#else
x_df[i*(MMQ_TILE_NE_K/QI4_0) + i/QI4_0 + kbxd] = bxi->d;
#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
}
}
@ -383,14 +383,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
constexpr int nwarps = mmq_get_nwarps_device();
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
int * x_qs = (int *) x_tile;
half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
#else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
int * x_qs = (int *) x_tile;
half2 * x_dm = (half2 *) (x_qs + txs.qs);
#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_1);
constexpr int nrows = warp_size / threads_per_row;
@ -409,12 +409,12 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx;
const int qs0 = get_int_b4(bxi->qs, kqsx);
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + 0] = (qs0 >> 0) & 0x0F0F0F0F;
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + QI4_1] = (qs0 >> 4) & 0x0F0F0F0F;
#else
x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0;
#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
}
constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_1;
@ -431,11 +431,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbxd;
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm;
#else
x_dm[i*(MMQ_TILE_NE_K/QI4_1) + i/QI4_1 + kbxd] = bxi->dm;
#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
}
}
@ -486,14 +486,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
constexpr int nwarps = mmq_get_nwarps_device();
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
#else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y);
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + txs.qs);
#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_0);
constexpr int nrows = warp_size / threads_per_row;
@ -528,13 +528,13 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
qs1 |= (qh << 9) & 0x10000000; // 19 -> 28
qs1 = __vsubss4(qs1, 0x10101010); // subtract 16
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + 0] = qs0;
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
#else
x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_0) + kqsx + 0] = qs0;
x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
}
constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI5_0;
@ -551,11 +551,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbxd;
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
#else
x_df[i*(MMQ_TILE_NE_K/QI5_0) + i/QI5_0 + kbxd] = bxi->d;
#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
}
}
@ -564,14 +564,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
constexpr int nwarps = mmq_get_nwarps_device();
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
int * x_qs = (int *) x_tile;
half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
#else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);
int * x_qs = (int *) x_tile;
half2 * x_dm = (half2 *) (x_qs + txs.qs);
#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_1);
constexpr int nrows = warp_size / threads_per_row;
@ -604,13 +604,13 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
qs1 |= (qh << 2) & 0x00100000; // 18 -> 20
qs1 |= (qh << 9) & 0x10000000; // 19 -> 28
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + 0] = qs0;
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
#else
x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_1) + kqsx + 0] = qs0;
x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
}
constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI5_1;
@ -627,11 +627,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbxd;
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm;
#else
x_dm[i*(MMQ_TILE_NE_K/QI5_1) + i/QI5_1 + kbxd] = bxi->dm;
#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
}
}
@ -640,14 +640,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
constexpr int nwarps = mmq_get_nwarps_device();
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_tile + 2*MMQ_TILE_NE_K);
#else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + txs.qs);
#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
// MMQ_ITER_K / (4 * QR8_0) == 64 required. but NV has only 32 threads per warp
constexpr int threads_per_row = 32;
@ -666,13 +666,13 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx;
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 0 + txi] = get_int_b2(bxi[0].qs, kqsx);
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + MMQ_TILE_NE_K + txi] = get_int_b2(bxi[MMQ_TILE_NE_K/QI8_0].qs, kqsx);
#else
x_qs[i*(2*MMQ_TILE_NE_K + 1) + 0 + txi] = get_int_b2(bxi[0].qs, kqsx);
x_qs[i*(2*MMQ_TILE_NE_K + 1) + MMQ_TILE_NE_K + txi] = get_int_b2(bxi[MMQ_TILE_NE_K/QI8_0].qs, kqsx);
#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
}
constexpr int blocks_per_tile_x_row = 2*MMQ_TILE_NE_K / QI8_0;
@ -689,11 +689,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbxd;
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
#else
x_df[i*(2*MMQ_TILE_NE_K/QI8_0) + i/(QI8_0/2) + kbxd] = bxi->d;
#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
}
}
@ -702,14 +702,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
constexpr int nwarps = mmq_get_nwarps_device();
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
#else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_MXFP4, mmq_y);
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + txs.qs);
#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
constexpr int threads_per_row = MMQ_ITER_K / (4 * QR_MXFP4);
constexpr int nrows = warp_size / threads_per_row;
@ -731,13 +731,13 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
const int2 v = get_int_from_table_16(aux_q4, kvalues_mxfp4);
const int k0 = kbx * (2 * QI_MXFP4) + kqsx;
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + k0 + 0] = v.x;
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + k0 + QI_MXFP4] = v.y;
#else
x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x;
x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + QI_MXFP4] = v.y;
#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
}
constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI_MXFP4;
@ -754,11 +754,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i*stride + kbxd;
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
x_df[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = ggml_cuda_e8m0_to_fp32(bxi->e)*0.5f;
#else
x_df[i*(MMQ_TILE_NE_K/QI_MXFP4) + i/QI_MXFP4 + kbxd] = ggml_cuda_e8m0_to_fp32(bxi->e)*0.5f;
#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
}
}
@ -1179,7 +1179,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
}
}
}
#elif defined(NEW_MMA_AVAILABLE)
#elif defined(TURING_MMA_AVAILABLE)
typedef tile<16, 4, int> tile_A;
typedef tile<16, 8, int> tile_A_8;
@ -1265,14 +1265,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
constexpr int nwarps = mmq_get_nwarps_device();
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
int * x_qs = (int *) x_tile;
half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
#else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
int * x_qs = (int *) x_tile;
half2 * x_dm = (half2 *) (x_qs + txs.qs);
#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
constexpr int threads_per_row = MMQ_ITER_K / (4 * QR2_K);
constexpr int nrows = ggml_cuda_get_physical_warp_size() / threads_per_row;
@ -1296,11 +1296,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
const int x_qs_k = (x_ql_0 >> (2*l)) & 0x03030303;
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k] = x_qs_k;
#else
x_qs[i*(2*MMQ_TILE_NE_K + 1) + k] = x_qs_k;
#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
}
const int sc_m = bxi->scales[kqsx];
@ -1311,11 +1311,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
const half2 x_dm_ik = make_half2(bxi_dmf.x*(sc_m & 0x0F), bxi_dmf.y*(sc_m >> 4));
#endif // FAST_FP16_AVAILABLE
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + kqsx] = x_dm_ik;
#else
x_dm[i*(MMQ_TILE_NE_K + 1) + kqsx] = x_dm_ik;
#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
}
}
@ -1453,7 +1453,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
}
}
}
#elif defined(NEW_MMA_AVAILABLE)
#elif defined(TURING_MMA_AVAILABLE)
typedef tile<16, 4, int> tile_A;
typedef tile<16, 8, int> tile_A_8;
@ -1583,7 +1583,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
constexpr int nwarps = mmq_get_nwarps_device();
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
#else
@ -1591,7 +1591,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + txs.qs);
int * x_sc = (int *) (x_df + txs.dm);
#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
constexpr int threads_per_row = MMQ_ITER_K / (4 * QR3_K);
constexpr int nrows = warp_size / threads_per_row;
@ -1619,11 +1619,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
const int x_qs_k = __vsubss4(x_ql_k | x_qh_k, 0x04040404);
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k] = x_qs_k;
#else
x_qs[i*(2*MMQ_TILE_NE_K + 1) + k] = x_qs_k;
#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
}
}
@ -1650,7 +1650,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
const int sc = __vsubss4(sc_low | sc_high, 0x20202020);
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
const int8_t * sc8 = (const int8_t *) &sc;
const float d = bxi->d;
@ -1660,10 +1660,10 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
}
#else
x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = sc;
#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
}
#if !(defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE))
#if !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE))
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) {
int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y;
@ -1676,7 +1676,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
x_df[i] = bxi->d;
}
#endif // !(defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE))
#endif // !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE))
}
template <int mmq_x, int mmq_y>
@ -1729,7 +1729,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
constexpr int nwarps = mmq_get_nwarps_device();
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
int * x_qs = (int *) x_tile;
half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
#else
@ -1737,7 +1737,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
int * x_qs = (int *) x_tile;
half2 * x_dm = (half2 *) (x_qs + txs.qs);
int * x_sc = (int *) (x_dm + txs.dm);
#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_K);
constexpr int nrows = warp_size / threads_per_row;
@ -1754,15 +1754,15 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;
const int qs0 = get_int_b4(bxi->qs, txi);
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(txi/8) + txi % 8 + 0] = (qs0 >> 0) & 0x0F0F0F0F;
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(txi/8) + txi % 8 + 8] = (qs0 >> 4) & 0x0F0F0F0F;
#else
x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0;
#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
}
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
constexpr int rows_per_warp = warp_size / 2;
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
@ -1830,7 +1830,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = scales8;
}
#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
}
template <int mmq_x, int mmq_y>
@ -1873,7 +1873,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
constexpr int nwarps = mmq_get_nwarps_device();
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
int * x_qs = (int *) x_tile;
half2 * x_dm = (half2 *) (x_qs + MMQ_TILE_NE_K*2);
#else
@ -1881,7 +1881,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
int * x_qs = (int *) x_tile;
half2 * x_dm = (half2 *) (x_qs + txs.qs);
int * x_sc = (int *) (x_dm + txs.dm);
#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_K);
constexpr int nrows = warp_size / threads_per_row;
@ -1909,16 +1909,16 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
const int kq0 = ky - ky % (QI5_K/2) + txi % (QI5_K/4) + 0;
const int kq1 = ky - ky % (QI5_K/2) + txi % (QI5_K/4) + QI5_K/4;
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq0] = ql0 | qh0;
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq1] = ql1 | qh1;
#else
x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq0] = ql0 | qh0;
x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq1] = ql1 | qh1;
#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
}
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
constexpr int rows_per_warp = warp_size / 2;
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
@ -1987,7 +1987,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = scales8;
}
#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
}
template <int mmq_x, int mmq_y>
@ -2030,7 +2030,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
constexpr int nwarps = mmq_get_nwarps_device();
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
int * x_sc = (int *) (x_df + MMQ_TILE_NE_K/QI6_K);
@ -2039,7 +2039,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + txs.qs);
int * x_sc = (int *) (x_df + txs.dm);
#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
constexpr int threads_per_row = MMQ_ITER_K / (4 * QR6_K);
constexpr int nrows = warp_size / threads_per_row;
@ -2066,13 +2066,13 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
const int kq0 = 2*txi - txi % (QI6_K/2) + 0;
const int kq1 = 2*txi - txi % (QI6_K/2) + QI6_K/2;
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
#else
x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
}
#pragma unroll
@ -2085,11 +2085,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride;
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
x_df[i*MMQ_MMA_TILE_X_K_Q6_K] = bxi->d;
#else
x_df[i*(MMQ_TILE_NE_K/QI6_K) + i/QI6_K] = bxi->d;
#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
}
constexpr int rows_per_warp = warp_size / 4;
@ -2103,11 +2103,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (MMQ_TILE_NE_K/8)) / 4;
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x%4] = get_int_b2(bxi->scales, threadIdx.x % (MMQ_TILE_NE_K/8));
#else
x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + threadIdx.x%(MMQ_TILE_NE_K/8)] = get_int_b2(bxi->scales, threadIdx.x%(QI6_K/8));
#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
}
}
@ -2200,7 +2200,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
}
}
}
#elif defined(NEW_MMA_AVAILABLE)
#elif defined(TURING_MMA_AVAILABLE)
typedef tile<16, 4, int> tile_A;
typedef tile< 8, 4, int> tile_B;
@ -2312,14 +2312,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
constexpr int nwarps = mmq_get_nwarps_device();
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
#else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_NL, mmq_y);
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + txs.qs);
#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_NL);
constexpr int nrows = warp_size / threads_per_row;
@ -2341,13 +2341,13 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);
const int k0 = kbx * (2 * QI4_NL) + kqsx;
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + QI4_NL] = v.y;
#else
x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x;
x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + QI4_NL] = v.y;
#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
}
constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_NL;
@ -2364,11 +2364,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbxd;
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = __half2float(bxi->d);
#else
x_df[i*(MMQ_TILE_NE_K/QI4_NL) + i/QI4_NL + kbxd] = __half2float(bxi->d);
#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
}
}
@ -2377,14 +2377,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
constexpr int nwarps = mmq_get_nwarps_device();
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
#else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_XXS, mmq_y);
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + txs.qs);
#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_XXS)) / 2;
constexpr int nrows = warp_size / threads_per_row;
@ -2415,22 +2415,22 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000);
const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1);
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid0;
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid1;
#else
x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid0;
x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid1;
#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
}
const int ls = aux32 >> 28;
const float d = bxi->d;
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/4;
#else
x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/4;
#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
}
}
@ -2439,14 +2439,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
constexpr int nwarps = mmq_get_nwarps_device();
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
#else
constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16;
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + txs.qs);
#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_XS)) / 2;
constexpr int nrows = warp_size / threads_per_row;
@ -2473,24 +2473,24 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]);
const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]);
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h;
#else
x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l;
x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h;
#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
}
const int ls = bxi->scales[kqsx];
const float d = bxi->d;
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
#else
x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
}
}
@ -2499,14 +2499,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
constexpr int nwarps = mmq_get_nwarps_device();
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
#else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_S, mmq_y);
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + txs.qs);
#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_S)) / 2;
constexpr int nrows = warp_size / threads_per_row;
@ -2540,24 +2540,24 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
const int grid_l = __vsub4(grid_pos[0] ^ signs0, signs0);
const int grid_h = __vsub4(grid_pos[1] ^ signs1, signs1);
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h;
#else
x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l;
x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h;
#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
}
const int ls = bxi->scales[kqsx];
const float d = bxi->d;
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
#else
x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
}
}
@ -2566,14 +2566,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
constexpr int nwarps = mmq_get_nwarps_device();
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
#else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_XXS, mmq_y);
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + txs.qs);
#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR3_XXS)) / 2;
constexpr int nrows = warp_size / threads_per_row;
@ -2602,22 +2602,22 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]);
const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]);
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid_l;
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid_h;
#else
x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l;
x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h;
#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
}
const int ls = aux32 >> 28;
const float d = bxi->d;
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/2;
#else
x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/2;
#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
}
}
@ -2626,14 +2626,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
constexpr int nwarps = mmq_get_nwarps_device();
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
#else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y);
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + txs.qs);
#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR3_S)) / 2;
constexpr int nrows = warp_size / threads_per_row;
@ -2669,22 +2669,22 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+0)] = grid_l;
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+1)] = grid_h;
#else
x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+0)] = grid_l;
x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+1)] = grid_h;
#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
}
const int ls = 1 + 2*((bxi->scales[kqsx/2] >> (((2*kqsx) << 1) & 0x04)) & 0x0F);
const float d = bxi->d;
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = ls*d;
#else
x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = ls*d;
#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
}
}
@ -2693,14 +2693,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
constexpr int nwarps = mmq_get_nwarps_device();
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
int * x_qs = (int *) x_tile;
half2 * x_ds = (half2 *) (x_qs + MMQ_TILE_NE_K*2);
#else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y);
int * x_qs = (int *) x_tile;
half2 * x_ds = (half2 *) (x_qs + txs.qs);
#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
constexpr int threads_per_row = MMQ_ITER_K / (4 * QR1_S);
constexpr int nrows = warp_size / threads_per_row;
@ -2728,23 +2728,23 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
const int grid0 = (grid >> 0) & 0x0F0F0F0F;
const int grid1 = (grid >> 4) & 0x0F0F0F0F;
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+0)] = grid0;
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+1)] = grid1;
#else
x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+0)] = grid0;
x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+1)] = grid1;
#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
}
const float d1q = __half2float(bxi->d) * (((qh >> 11) & 0x0E) + 1);
const float delta = -1.0f + IQ1S_DELTA - (qh & 0x8000) * (2.0f*IQ1S_DELTA/0x8000);
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
x_ds[i*MMQ_MMA_TILE_X_K_Q8_1 + kqsx] = make_half2(d1q, d1q*delta);
#else
x_ds[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = make_half2(d1q, d1q*delta);
#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
}
}
@ -2753,14 +2753,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
constexpr int nwarps = mmq_get_nwarps_device();
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
#else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y);
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + txs.qs);
#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_XS);
constexpr int nrows = warp_size / threads_per_row;
@ -2780,13 +2780,13 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);
const int k0 = 8 * (kqsx / 4) + kqsx % 4;
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;
#else
x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x;
x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 4] = v.y;
#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
}
constexpr int rows_per_warp = warp_size / 8;
@ -2805,11 +2805,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
const int ls = ((bxi->scales_l[(threadIdx.x % 8)/2] >> (4*(threadIdx.x % 2))) & 0x0F)
| (((bxi->scales_h >> (2*(threadIdx.x % 8))) & 0x03) << 4);
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = d * (ls - 32);
#else
x_df[i*(MMQ_TILE_NE_K/4) + i/4 + threadIdx.x % 8] = d * (ls - 32);
#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
}
}
@ -2860,9 +2860,9 @@ static __device__ __forceinline__ void mmq_write_back_mma(
constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
const int i0 = (threadIdx.y / ntx) * (ntx*tile_C::I);
#if defined(NEW_MMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
#if defined(TURING_MMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
static_assert(nwarps*tile_C::I == mmq_y, "nwarps*tile_C::I != mmq_y");
#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
#pragma unroll
for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
@ -3062,13 +3062,13 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
int * tile_y = data_mul_mat_q + mmq_x;
int * tile_x = tile_y + GGML_PAD(mmq_x*MMQ_TILE_Y_K, nwarps*warp_size);
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, need_check, type>::vec_dot_mma;
constexpr mmq_write_back_t write_back = mmq_write_back_mma<type, mmq_x, mmq_y, need_check>;
#else
constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, need_check, type>::vec_dot_dp4a;
constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, need_check>;
#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
constexpr int blocks_per_iter = MMQ_ITER_K / qk;
@ -3535,7 +3535,7 @@ static size_t mmq_get_nbytes_shared(const int mmq_x, const int mmq_y, const int
const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(type, mmq_y);
const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type);
const size_t nbs_ids = mmq_x*sizeof(int);
const size_t nbs_x = (new_mma_available(cc) || amd_mfma_available(cc)) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
const size_t nbs_x = (turing_mma_available(cc) || amd_mfma_available(cc)) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
const size_t nbs_y = mmq_x*sizeof(block_q8_1_mmq);
return nbs_ids + nbs_x + GGML_PAD(nbs_y, nwarps*warp_size*sizeof(int));
}
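Note on the mmq.cuh hunks above: they are a pure rename of the MMA availability guard, NEW_MMA_AVAILABLE becoming TURING_MMA_AVAILABLE (and new_mma_available() becoming turing_mma_available() in mmq_get_nbytes_shared); the tile layouts selected under the guard are unchanged. A hedged sketch of the predicate the new name describes, reusing only the compute-capability helpers already visible in this diff rather than the actual common.cuh definition:

// Sketch only: Turing-class NVIDIA GPUs (compute capability >= 7.5) expose the mma.sync path.
static bool turing_mma_available(const int cc) {
    return GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_TURING;
}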

View file

@ -1,9 +1,9 @@
#include "ggml.h"
#include "common.cuh"
#include "mmv.cuh"
#include "mmvf.cuh"
template <typename T, typename type_acc, int ncols_dst, int block_size>
static __global__ void mul_mat_vec(
static __global__ void mul_mat_vec_f(
const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
const int ncols2, const int nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst,
const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
@ -37,7 +37,7 @@ static __global__ void mul_mat_vec(
float sumf[ncols_dst] = {0.0f};
if constexpr (std::is_same<T, float>::value) {
if constexpr (std::is_same_v<T, float>) {
const float2 * x2 = (const float2 *) x;
for (int col2 = tid; col2 < ncols2; col2 += block_size) {
@ -50,10 +50,10 @@ static __global__ void mul_mat_vec(
sumf[j] += tmpx.y*tmpy.y;
}
}
} else if constexpr (std::is_same<T, half>::value) {
} else if constexpr (std::is_same_v<T, half>) {
const half2 * x2 = (const half2 *) x;
if (std::is_same<type_acc, float>::value) {
if (std::is_same_v<type_acc, float>) {
for (int col2 = tid; col2 < ncols2; col2 += block_size) {
const float2 tmpx = __half22float2(x2[col2]);
@ -86,7 +86,7 @@ static __global__ void mul_mat_vec(
NO_DEVICE_CODE;
#endif // FP16_AVAILABLE
}
} else if constexpr (std::is_same<T, nv_bfloat16>::value) {
} else if constexpr (std::is_same_v<T, nv_bfloat16>) {
const int * x2 = (const int *) x;
for (int col2 = tid; col2 < ncols2; col2 += block_size) {
const int tmpx = x2[col2];
@ -98,7 +98,7 @@ static __global__ void mul_mat_vec(
}
}
} else {
static_assert(std::is_same<T, void>::value, "unsupported type");
static_assert(std::is_same_v<T, void>, "unsupported type");
}
#pragma unroll
@ -126,7 +126,7 @@ static __global__ void mul_mat_vec(
}
template <typename T, typename type_acc, int ncols_dst>
static void launch_mul_mat_vec_cuda(
static void launch_mul_mat_vec_f_cuda(
const T * x, const float * y, const int32_t * ids, float * dst,
const int64_t ncols, const int64_t nrows,
const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
@ -141,11 +141,9 @@ static void launch_mul_mat_vec_cuda(
GGML_ASSERT( nsamples_dst % nsamples_x == 0);
const int64_t channel_ratio = nchannels_dst / nchannels_x;
const int64_t sample_ratio = nsamples_dst / nsamples_x;
int device;
int warp_size;
CUDA_CHECK(cudaGetDevice(&device));
warp_size = ggml_cuda_info().devices[device].warp_size;
const int device = ggml_cuda_get_device();
const int warp_size = ggml_cuda_info().devices[device].warp_size;
int64_t block_size_best = warp_size;
int64_t niter_best = (ncols + 2*warp_size - 1) / (2*warp_size);
@ -161,54 +159,54 @@ static void launch_mul_mat_vec_cuda(
}
}
const int smem = warp_size*sizeof(float);
const int nbytes_shared = warp_size*sizeof(float);
const dim3 block_nums(nrows, nchannels_dst, nsamples_dst);
const dim3 block_dims(block_size_best, 1, 1);
switch (block_size_best) {
case 32: {
mul_mat_vec<T, type_acc, ncols_dst, 32><<<block_nums, block_dims, smem, stream>>>
mul_mat_vec_f<T, type_acc, ncols_dst, 32><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 64: {
mul_mat_vec<T, type_acc, ncols_dst, 64><<<block_nums, block_dims, smem, stream>>>
mul_mat_vec_f<T, type_acc, ncols_dst, 64><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 96: {
mul_mat_vec<T, type_acc, ncols_dst, 96><<<block_nums, block_dims, smem, stream>>>
mul_mat_vec_f<T, type_acc, ncols_dst, 96><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 128: {
mul_mat_vec<T, type_acc, ncols_dst, 128><<<block_nums, block_dims, smem, stream>>>
mul_mat_vec_f<T, type_acc, ncols_dst, 128><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 160: {
mul_mat_vec<T, type_acc, ncols_dst, 160><<<block_nums, block_dims, smem, stream>>>
mul_mat_vec_f<T, type_acc, ncols_dst, 160><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 192: {
mul_mat_vec<T, type_acc, ncols_dst, 192><<<block_nums, block_dims, smem, stream>>>
mul_mat_vec_f<T, type_acc, ncols_dst, 192><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 224: {
mul_mat_vec<T, type_acc, ncols_dst, 224><<<block_nums, block_dims, smem, stream>>>
mul_mat_vec_f<T, type_acc, ncols_dst, 224><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 256: {
mul_mat_vec<T, type_acc, ncols_dst, 256><<<block_nums, block_dims, smem, stream>>>
mul_mat_vec_f<T, type_acc, ncols_dst, 256><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
@ -220,7 +218,7 @@ static void launch_mul_mat_vec_cuda(
}
template <typename T, typename type_acc>
static void mul_mat_vec_cuda_switch_ncols_dst(
static void mul_mat_vec_f_cuda_switch_ncols_dst(
const T * x, const float * y, const int32_t * ids, float * dst,
const int64_t ncols, const int64_t nrows, const int64_t ncols_dst,
const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
@ -230,49 +228,49 @@ static void mul_mat_vec_cuda_switch_ncols_dst(
cudaStream_t stream) {
switch (ncols_dst) {
case 1:
launch_mul_mat_vec_cuda<T, type_acc, 1>
launch_mul_mat_vec_f_cuda<T, type_acc, 1>
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case 2:
launch_mul_mat_vec_cuda<T, type_acc, 2>
launch_mul_mat_vec_f_cuda<T, type_acc, 2>
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case 3:
launch_mul_mat_vec_cuda<T, type_acc, 3>
launch_mul_mat_vec_f_cuda<T, type_acc, 3>
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case 4:
launch_mul_mat_vec_cuda<T, type_acc, 4>
launch_mul_mat_vec_f_cuda<T, type_acc, 4>
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case 5:
launch_mul_mat_vec_cuda<T, type_acc, 5>
launch_mul_mat_vec_f_cuda<T, type_acc, 5>
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case 6:
launch_mul_mat_vec_cuda<T, type_acc, 6>
launch_mul_mat_vec_f_cuda<T, type_acc, 6>
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case 7:
launch_mul_mat_vec_cuda<T, type_acc, 7>
launch_mul_mat_vec_f_cuda<T, type_acc, 7>
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case 8:
launch_mul_mat_vec_cuda<T, type_acc, 8>
launch_mul_mat_vec_f_cuda<T, type_acc, 8>
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
@ -284,7 +282,7 @@ static void mul_mat_vec_cuda_switch_ncols_dst(
}
template<typename T>
static void mul_mat_vec_cuda(
static void mul_mat_vec_f_cuda(
const T * x, const float * y, const int32_t * ids, float * dst,
const int64_t ncols, const int64_t nrows, const int64_t ncols_dst,
const int64_t stride_row, const int64_t stride_col_y, const int stride_col_dst,
@ -292,22 +290,22 @@ static void mul_mat_vec_cuda(
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
enum ggml_prec prec, cudaStream_t stream) {
if constexpr(std::is_same<T, half>::value) {
if constexpr(std::is_same_v<T, half>) {
if (prec == GGML_PREC_DEFAULT) {
mul_mat_vec_cuda_switch_ncols_dst<T, half>
mul_mat_vec_f_cuda_switch_ncols_dst<T, half>
(x, y, ids, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
return;
}
}
mul_mat_vec_cuda_switch_ncols_dst<T, float>
mul_mat_vec_f_cuda_switch_ncols_dst<T, float>
(x, y, ids, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
}
void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
GGML_ASSERT( src1->type == GGML_TYPE_F32);
GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
@ -355,19 +353,19 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor *
switch (src0->type) {
case GGML_TYPE_F32: {
const float * src0_d = (const float *) src0->data;
mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
ne03, ne3, s03, s13, s3, prec, ctx.stream());
} break;
case GGML_TYPE_F16: {
const half * src0_d = (const half *) src0->data;
mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
ne03, ne3, s03, s13, s3, prec, ctx.stream());
} break;
case GGML_TYPE_BF16: {
const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data;
mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
ne03, ne3, s03, s13, s3, prec, ctx.stream());
} break;
@ -376,7 +374,7 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor *
}
}
void ggml_cuda_op_mul_mat_vec(
void ggml_cuda_op_mul_mat_vec_f(
ggml_backend_cuda_context & ctx,
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
@ -414,19 +412,19 @@ void ggml_cuda_op_mul_mat_vec(
switch (src0->type) {
case GGML_TYPE_F32: {
const float * src0_d = (const float *) src0_dd_i;
mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
} break;
case GGML_TYPE_F16: {
const half * src0_d = (const half *) src0_dd_i;
mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
} break;
case GGML_TYPE_BF16: {
const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i;
mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
} break;
@ -442,15 +440,15 @@ void ggml_cuda_op_mul_mat_vec(
GGML_UNUSED(src1_padded_row_size);
}
bool ggml_cuda_should_use_mmv(enum ggml_type type, int cc, const int64_t * src0_ne, int64_t ne11) {
bool ggml_cuda_should_use_mmvf(enum ggml_type type, int cc, const int64_t * src0_ne, int64_t ne11) {
if (src0_ne[0] % 2 != 0) {
return false;
}
switch (type) {
case GGML_TYPE_F32:
if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
return ne11 <= 8;
if (ampere_mma_available(cc)) {
return ne11 <= 3;
}
if (cc >= GGML_CUDA_CC_TURING) {
return ne11 <= 4;
@ -466,6 +464,9 @@ bool ggml_cuda_should_use_mmv(enum ggml_type type, int cc, const int64_t * src0_
case GGML_TYPE_F16:
if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
const bool src0_small = (src0_ne[1] <= 512 || src0_ne[2]*src0_ne[3] == 1);
if (ampere_mma_available(cc)) {
return src0_small && ne11 == 1;
}
if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
return src0_small && ne11 <= 4;
}
@ -486,6 +487,9 @@ bool ggml_cuda_should_use_mmv(enum ggml_type type, int cc, const int64_t * src0_
case GGML_TYPE_BF16:
if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
const bool src0_small = (src0_ne[1] <= 512 || src0_ne[2]*src0_ne[3] == 1);
if (ampere_mma_available(cc)) {
return src0_small && ne11 == 1;
}
if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
return src0_small && ne11 <= 4;
}
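Note on the ggml_cuda_should_use_mmvf hunks above: besides the mmv to mmvf rename, the heuristic now tests ampere_mma_available(cc) before the Ada Lovelace / Turing checks, so Ampere-class and newer NVIDIA GPUs keep the float mul-mat-vec kernels only for very small batches (ne11 <= 3 for F32, ne11 == 1 for small F16/BF16 matrices) and presumably hand larger batches to the tensor-core path. An illustrative sketch of the visible part of the F32 branch; the pre-Turing handling sits outside the hunk and the helper name below is not from the patch:

// Illustrative sketch of the new F32 decision for NVIDIA devices.
static bool use_mmvf_f32_nvidia_sketch(const int cc, const int64_t ne11) {
    if (ampere_mma_available(cc)) {
        return ne11 <= 3;  // replaces the removed Ada Lovelace branch (ne11 <= 8)
    }
    if (cc >= GGML_CUDA_CC_TURING) {
        return ne11 <= 4;
    }
    return false;          // placeholder: older-GPU branches are not shown in the diff
}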

View file

@ -1,11 +1,11 @@
#include "common.cuh"
void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
void ggml_cuda_op_mul_mat_vec(
void ggml_cuda_op_mul_mat_vec_f(
ggml_backend_cuda_context & ctx,
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
const int64_t src1_padded_row_size, cudaStream_t stream);
bool ggml_cuda_should_use_mmv(enum ggml_type type, int cc, const int64_t * src0_ne, int64_t ne11);
bool ggml_cuda_should_use_mmvf(enum ggml_type type, int cc, const int64_t * src0_ne, int64_t ne11);

View file

@ -200,6 +200,7 @@
#endif
typedef hip_bfloat16 nv_bfloat16;
typedef short2 nv_bfloat162; // FIXME: there is no 2x BF16 type defined in bfloat16.h; ad-hoc compilation fix
typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));
typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4)));

View file

@ -137,4 +137,5 @@
#define cudaStreamEndCapture musaStreamEndCapture
#define cudaOccupancyMaxActiveBlocksPerMultiprocessor musaOccupancyMaxActiveBlocksPerMultiprocessor
typedef mt_bfloat16 nv_bfloat16;
typedef __mt_bfloat16 nv_bfloat16;
typedef __mt_bfloat162 nv_bfloat162;

View file

@ -0,0 +1,42 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
//------------------------------------------------------------------------------
// add_id
//------------------------------------------------------------------------------
kernel void kernel_add_id(
global char * src0,
ulong offset0,
global char * src1,
ulong offset1,
global char * src2,
ulong offset2,
global char * dst,
ulong offsetd,
ulong nb01,
ulong nb02,
ulong nb11,
ulong nb21,
int ne0,
int ne1
) {
src0 = (global char*)((global char*)src0 + offset0);
src1 = (global char*)((global char*)src1 + offset1);
src2 = (global char*)((global char*)src2 + offset2);
dst = (global char*)((global char*)dst + offsetd);
int i1 = get_group_id(0);
int i2 = get_group_id(1);
const int i11 = *((global const int *) (src2 + i1*sizeof(int) + i2*nb21));
const size_t nb1 = ne0 * sizeof(float);
const size_t nb2 = ne1 * nb1;
global float * dst_row = (global float *)((global char *)dst + i1*nb1 + i2*nb2);
global float * src0_row = (global float *)((global char *)src0 + i1*nb01 + i2*nb02);
global float * src1_row = (global float *)((global char *)src1 + i11*nb11);
for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
dst_row[i0] = src0_row[i0] + src1_row[i0];
}
}
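The add_id kernel above adds, to each row (i1, i2) of src0, the row of src1 selected by the integer stored in src2 at (i1, i2). As a sanity check, here is a self-contained C++ reference of the same computation; the contiguous row-major layout and all names are assumptions made for this sketch, not part of the OpenCL host code.

```cpp
// Reference sketch (not the OpenCL host code): add_id on plain arrays.
// Each output row (i1, i2) is src0's row plus the src1 row whose index is
// stored in src2[i2*ne1 + i1]. Contiguous row-major storage is assumed.
#include <vector>

void add_id_reference(const std::vector<float> & src0,   // ne1*ne2 rows of ne0 floats
                      const std::vector<float> & src1,   // expert rows of ne0 floats
                      const std::vector<int>   & src2,   // ne1*ne2 row indices into src1
                      std::vector<float>       & dst,    // ne1*ne2 rows of ne0 floats
                      int ne0, int ne1, int ne2) {
    for (int i2 = 0; i2 < ne2; ++i2) {
        for (int i1 = 0; i1 < ne1; ++i1) {
            const int i11 = src2[i2*ne1 + i1];            // which src1 row to add
            const int row = (i2*ne1 + i1)*ne0;
            for (int i0 = 0; i0 < ne0; ++i0) {
                dst[row + i0] = src0[row + i0] + src1[i11*ne0 + i0];
            }
        }
    }
}
```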

View file

@ -202,6 +202,47 @@ kernel void kernel_swiglu_f16(
}
}
//------------------------------------------------------------------------------
// swiglu_oai
//------------------------------------------------------------------------------
kernel void kernel_swiglu_oai(
global char * src0,
ulong offset0,
global char * src1,
ulong offset1,
global char * dst,
ulong offsetd,
ulong nb01,
ulong nb11,
int ne0,
ulong nb1,
int ne00_off,
int ne10_off,
float limit,
float alpha
) {
src0 = (global char*)((global char*)src0 + offset0);
src1 = (global char*)((global char*)src1 + offset1);
dst = (global char*)((global char*)dst + offsetd);
global float * src0_row = (global float *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
global float * src1_row = (global float *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
global float * dst_row = (global float *) ((global char *) dst + get_group_id(0)*nb1);
for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
float x0 = src0_row[i0];
float x1 = src1_row[i0];
x0 = min(x0, limit);
x1 = max(min(x1, limit), -limit);
float out_glu = x0 / (1.0f + exp(-x0 * alpha));
out_glu = out_glu * (1.0f + x1);
dst_row[i0] = out_glu;
}
}
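Element-wise, kernel_swiglu_oai clamps the gate x0 from above at limit, clamps x1 to [-limit, limit], and outputs x0·sigmoid(alpha·x0)·(1 + x1). A scalar C++ mirror of that math, handy for checking the kernel against a CPU result:

```cpp
// Scalar reference of the swiglu_oai element function in the kernel above.
// Plain C++, no OpenCL; useful for comparing GPU output against a CPU value.
#include <algorithm>
#include <cmath>

float swiglu_oai_ref(float x0, float x1, float alpha, float limit) {
    x0 = std::min(x0, limit);                                // clamp gate from above
    x1 = std::max(std::min(x1, limit), -limit);              // clamp linear part
    const float glu = x0 / (1.0f + std::exp(-x0 * alpha));   // x0 * sigmoid(alpha * x0)
    return glu * (1.0f + x1);
}
```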
//------------------------------------------------------------------------------
// geglu_erf
//------------------------------------------------------------------------------

View file

@ -20,6 +20,7 @@
#ifdef GGML_WEBGPU_DEBUG
# define WEBGPU_LOG_DEBUG(msg) std::cout << msg << std::endl
# define WEBGPU_DEBUG_BUF_ELEMS 32
#else
# define WEBGPU_LOG_DEBUG(msg) ((void) 0)
#endif // GGML_WEBGPU_DEBUG
@ -29,7 +30,9 @@
#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 16
#define WEBGPU_MUL_MAT_WG_SIZE 64
#define WEBGPU_NUM_PARAM_BUFS 100
#define WEBGPU_PARAMS_BUF_SIZE_BYTES 256
#define WEBGPU_PARAMS_BUF_SIZE_BYTES 128 // enough for 32 parameters
#define WEBGPU_NUM_SET_ROWS_ERROR_BUFS 32
#define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4
#define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4
/* End Constants */
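As the comment on WEBGPU_STORAGE_BUF_BINDING_MULT notes, a storage-buffer binding size must be a multiple of 4. Where the backend applies the rounding is not shown in this hunk; the sketch below is just the arithmetic, with a local copy of the constant.

```cpp
// Minimal sketch of the rounding implied by WEBGPU_STORAGE_BUF_BINDING_MULT:
// binding sizes are rounded up to the next multiple of 4. The constant is
// copied locally here; how the backend applies it is not part of this hunk.
#include <cstddef>

constexpr size_t kBindingMult = 4;  // mirrors WEBGPU_STORAGE_BUF_BINDING_MULT

constexpr size_t round_up_binding_size(size_t size) {
    return (size + kBindingMult - 1) & ~(kBindingMult - 1);
}

static_assert(round_up_binding_size(5) == 8, "5 bytes bind as 8");
static_assert(round_up_binding_size(8) == 8, "already a multiple of 4");
```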
@ -54,46 +57,42 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device,
wgpu::BufferUsage usage,
const char * label);
struct webgpu_param_bufs {
struct webgpu_pool_bufs {
wgpu::Buffer host_buf;
wgpu::Buffer dev_buf;
};
// Holds a pool of parameter buffers for WebGPU operations
struct webgpu_param_buf_pool {
std::vector<webgpu_param_bufs> free;
struct webgpu_buf_pool {
std::vector<webgpu_pool_bufs> free;
std::mutex mutex;
std::condition_variable cv;
void init(wgpu::Device device) {
for (int i = 0; i < WEBGPU_NUM_PARAM_BUFS; i++) {
void init(wgpu::Device device,
int num_bufs,
size_t buf_size,
wgpu::BufferUsage dev_buf_usage,
wgpu::BufferUsage host_buf_usage) {
for (int i = 0; i < num_bufs; i++) {
wgpu::Buffer host_buf;
wgpu::Buffer dev_buf;
ggml_webgpu_create_buffer(device,
host_buf,
WEBGPU_PARAMS_BUF_SIZE_BYTES,
wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite,
"ggml_webgpu_host_params_buf");
ggml_webgpu_create_buffer(device,
dev_buf,
WEBGPU_PARAMS_BUF_SIZE_BYTES,
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform,
"ggml_webgpu_dev_params_buf");
ggml_webgpu_create_buffer(device, host_buf, buf_size, host_buf_usage, "ggml_webgpu_host_pool_buf");
ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_pool_buf");
free.push_back({ host_buf, dev_buf });
}
}
webgpu_param_bufs alloc_bufs() {
webgpu_pool_bufs alloc_bufs() {
std::unique_lock<std::mutex> lock(mutex);
cv.wait(lock, [this] { return !free.empty(); });
webgpu_param_bufs bufs = free.back();
webgpu_pool_bufs bufs = free.back();
free.pop_back();
return bufs;
}
void free_bufs(std::vector<webgpu_param_bufs> bufs) {
void free_bufs(std::vector<webgpu_pool_bufs> bufs) {
std::lock_guard<std::mutex> lock(mutex);
free.insert(free.end(), bufs.begin(), bufs.end());
cv.notify_all();
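The generalized webgpu_buf_pool keeps a fixed set of host/device buffer pairs behind a mutex, and alloc_bufs blocks on the condition variable until free_bufs hands buffers back. Stripped of the WebGPU types, the pattern looks like the self-contained sketch below; the names here are illustrative, not part of the backend.

```cpp
// Standalone sketch of the pooling pattern used by webgpu_buf_pool above:
// a fixed free list guarded by a mutex, with alloc() blocking on a condition
// variable until free() returns buffers. Generic T stands in for the
// host/device wgpu::Buffer pair.
#include <condition_variable>
#include <mutex>
#include <vector>

template <typename T>
struct blocking_pool {
    std::vector<T>          free_list;
    std::mutex              mtx;
    std::condition_variable cv;

    void init(std::vector<T> items) { free_list = std::move(items); }

    T alloc() {
        std::unique_lock<std::mutex> lock(mtx);
        cv.wait(lock, [this] { return !free_list.empty(); });
        T item = free_list.back();
        free_list.pop_back();
        return item;
    }

    void free(std::vector<T> items) {
        std::lock_guard<std::mutex> lock(mtx);
        free_list.insert(free_list.end(), items.begin(), items.end());
        cv.notify_all();
    }
};
```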
@ -121,10 +120,12 @@ struct webgpu_context_struct {
bool device_init = false;
webgpu_param_buf_pool param_buf_pool;
webgpu_buf_pool param_buf_pool;
webgpu_buf_pool set_rows_error_buf_pool;
wgpu::ComputePipeline memset_pipeline;
wgpu::ComputePipeline mul_mat_pipeline;
wgpu::ComputePipeline set_rows_pipeline;
wgpu::ComputePipeline cpy_pipeline;
size_t memset_bytes_per_thread;
@ -136,9 +137,16 @@ struct webgpu_context_struct {
std::vector<wgpu::CommandBuffer> staged_command_bufs;
// Parameter buffers associated with the staged command buffers
std::vector<webgpu_param_bufs> staged_param_bufs;
std::vector<webgpu_pool_bufs> staged_param_bufs;
// Buffers associated with set_rows operations, used to store potential errors
std::vector<webgpu_pool_bufs> staged_set_row_error_bufs;
std::vector<wgpu::FutureWaitInfo> callback_futures;
#ifdef GGML_WEBGPU_DEBUG
wgpu::Buffer debug_host_buf;
wgpu::Buffer debug_dev_buf;
#endif
};
typedef std::shared_ptr<webgpu_context_struct> webgpu_context;
@ -249,20 +257,55 @@ static void ggml_backend_webgpu_submit_queue(webgpu_context & ctx) {
return;
}
ctx->queue.Submit(ctx->staged_command_bufs.size(), ctx->staged_command_bufs.data());
// If there are SET_ROWS operations in this submission, copy their error buffers to the host.
if (ctx->staged_set_row_error_bufs.size() > 0) {
wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder();
for (auto & error_bufs : ctx->staged_set_row_error_bufs) {
// Copy the error buffer to the host buffer
encoder.CopyBufferToBuffer(error_bufs.dev_buf, 0, error_bufs.host_buf, 0, error_bufs.host_buf.GetSize());
}
wgpu::CommandBuffer commands = encoder.Finish();
ctx->queue.Submit(1, &commands);
}
ctx->staged_command_bufs.clear();
std::vector<webgpu_param_bufs> staged_param_bufs = std::move(ctx->staged_param_bufs);
std::vector<webgpu_pool_bufs> staged_param_bufs = std::move(ctx->staged_param_bufs);
std::vector<webgpu_pool_bufs> staged_set_row_error_bufs = std::move(ctx->staged_set_row_error_bufs);
// Free the staged parameter buffers once the submission completes
wgpu::Future f = ctx->queue.OnSubmittedWorkDone(
wgpu::Future p_f = ctx->queue.OnSubmittedWorkDone(
wgpu::CallbackMode::AllowSpontaneous,
[ctx, staged_param_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
if (status != wgpu::QueueWorkDoneStatus::Success) {
GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", message.data);
}
// Free the staged parameter buffers
// Free the staged buffers
ctx->param_buf_pool.free_bufs(staged_param_bufs);
});
ctx->callback_futures.push_back({ p_f });
// Check for errors in SET_ROWS operations
for (auto & error_bufs : staged_set_row_error_bufs) {
wgpu::Future f = error_bufs.host_buf.MapAsync(
wgpu::MapMode::Read,
0,
error_bufs.host_buf.GetSize(),
wgpu::CallbackMode::AllowSpontaneous,
[ctx, error_bufs](wgpu::MapAsyncStatus status, wgpu::StringView message) {
if (status != wgpu::MapAsyncStatus::Success) {
GGML_LOG_ERROR("ggml_webgpu: Failed to map error buffer: %s\n", message.data);
} else {
const uint32_t * error_data = (const uint32_t *) error_bufs.host_buf.GetConstMappedRange();
if (*error_data) {
GGML_ABORT("ggml_webgpu: SET_ROWS index > 2^32, unsupported.");
}
// We can't unmap in here due to WebGPU reentrancy limitations.
ctx->set_rows_error_buf_pool.free_bufs({ error_bufs });
}
});
ctx->callback_futures.push_back({ f });
}
}
static void ggml_backend_webgpu_map_buffer(webgpu_context & ctx,
@ -283,13 +326,34 @@ static void ggml_backend_webgpu_map_buffer(webgpu_context & ctx,
UINT64_MAX);
}
#ifdef GGML_WEBGPU_DEBUG
// This function adds debugging information to shaders, as WebGPU does not support printing directly.
// To use, add a bind group entry to the setup for the shader you are debugging, add the buffer and
// debug statements in the shader, and then call this function after encoding the commands and submitting them.
static void ggml_backend_webgpu_debug(webgpu_context & ctx) {
wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder();
encoder.CopyBufferToBuffer(ctx->debug_dev_buf, 0, ctx->debug_host_buf, 0, ctx->debug_host_buf.GetSize());
wgpu::CommandBuffer commands = encoder.Finish();
ctx->queue.Submit(1, &commands);
ggml_backend_webgpu_map_buffer(ctx, ctx->debug_host_buf, wgpu::MapMode::Read, 0, ctx->debug_host_buf.GetSize());
const uint32_t * debug_data = (const uint32_t *) ctx->debug_host_buf.GetConstMappedRange();
std::cout << "debug data:";
for (size_t i = 0; i < WEBGPU_DEBUG_BUF_ELEMS; i++) {
std::cout << " " << i << ": " << debug_data[i];
}
std::cout << "\n";
ctx->debug_host_buf.Unmap();
}
#endif
static void ggml_backend_webgpu_build_and_enqueue(webgpu_context & ctx,
wgpu::ComputePipeline & pipeline,
std::vector<uint32_t> params,
std::vector<wgpu::BindGroupEntry> bind_group_entries,
uint32_t wg_x,
bool submit_and_wait = false) {
webgpu_param_bufs params_bufs = ctx->param_buf_pool.alloc_bufs();
webgpu_pool_bufs params_bufs = ctx->param_buf_pool.alloc_bufs();
ggml_backend_webgpu_map_buffer(ctx, params_bufs.host_buf, wgpu::MapMode::Write, 0, params_bufs.host_buf.GetSize());
uint32_t * _params = (uint32_t *) params_bufs.host_buf.GetMappedRange();
@ -429,6 +493,76 @@ static void ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor
ggml_backend_webgpu_build_and_enqueue(ctx, ctx->cpy_pipeline, params, entries, wg_x);
}
static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * idx, ggml_tensor * dst) {
// For set rows specifically, we need to check if src and idx are empty tensors.
if (ggml_is_empty(src) || ggml_is_empty(idx)) {
return;
}
webgpu_pool_bufs error_bufs = ctx->set_rows_error_buf_pool.alloc_bufs();
if (error_bufs.host_buf.GetMapState() == wgpu::BufferMapState::Mapped) {
error_bufs.host_buf.Unmap();
}
size_t src_offset = ggml_backend_webgpu_tensor_offset(src);
// assumes power of 2 offset alignment
size_t src_misalignment = src_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
// align to minimum offset alignment
src_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
size_t idx_offset = ggml_backend_webgpu_tensor_offset(idx);
size_t idx_misalignment = idx_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
idx_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
size_t dst_offset = ggml_backend_webgpu_tensor_offset(dst);
size_t dst_misalignment = dst_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
dst_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
std::vector<uint32_t> params = { (uint32_t) (src_misalignment / ggml_type_size(src->type)),
(uint32_t) (idx_misalignment / ggml_type_size(idx->type)),
(uint32_t) (dst_misalignment / ggml_type_size(dst->type)),
// Convert byte-strides to element-strides
(uint32_t) (src->nb[1] / ggml_type_size(src->type)),
(uint32_t) (src->nb[2] / ggml_type_size(src->type)),
(uint32_t) (src->nb[3] / ggml_type_size(src->type)),
(uint32_t) (idx->nb[0] / ggml_type_size(idx->type)),
(uint32_t) (idx->nb[1] / ggml_type_size(idx->type)),
(uint32_t) (idx->nb[2] / ggml_type_size(idx->type)),
(uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
(uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
(uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
// Shape of src
(uint32_t) src->ne[0],
(uint32_t) src->ne[1],
(uint32_t) src->ne[2],
(uint32_t) src->ne[3],
// Shape of idx
(uint32_t) (idx->ne[1]),
(uint32_t) (idx->ne[2]) };
std::vector<wgpu::BindGroupEntry> entries = {
{ .binding = 0,
.buffer = ggml_backend_webgpu_tensor_buf(src),
.offset = ggml_backend_webgpu_tensor_offset(src),
.size = ggml_nbytes(src) },
{ .binding = 1,
.buffer = ggml_backend_webgpu_tensor_buf(idx),
.offset = ggml_backend_webgpu_tensor_offset(idx),
.size = ggml_nbytes(idx) },
{ .binding = 2,
.buffer = ggml_backend_webgpu_tensor_buf(dst),
.offset = ggml_backend_webgpu_tensor_offset(dst),
.size = ggml_nbytes(dst) },
{ .binding = 3, .buffer = error_bufs.dev_buf, .offset = 0, .size = error_bufs.dev_buf.GetSize() }
};
size_t max_wg_size = ctx->limits.maxComputeWorkgroupSizeX;
uint32_t wg_x = (src->ne[1] * src->ne[2] * src->ne[3] + max_wg_size - 1) / max_wg_size;
std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
ctx->staged_set_row_error_bufs.push_back(error_bufs);
ggml_backend_webgpu_build_and_enqueue(ctx, ctx->set_rows_pipeline, params, entries, wg_x);
}
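ggml_webgpu_set_rows splits each tensor's byte offset into a base aligned down to minStorageBufferOffsetAlignment and a remainder that the shader receives as an element offset. A standalone sketch of that split, assuming a power-of-two alignment as the comment inside the function notes:

```cpp
// Standalone sketch of the offset split computed in ggml_webgpu_set_rows
// above: the byte offset is aligned down to the device's
// minStorageBufferOffsetAlignment and the remainder is expressed in elements
// for the shader's params. Assumes align is a power of two.
#include <cassert>
#include <cstddef>
#include <cstdio>

struct split_offset {
    size_t aligned_base;  // offset rounded down to the alignment
    size_t elem_offset;   // remainder, in elements, for the params buffer
};

split_offset split(size_t byte_offset, size_t align, size_t type_size) {
    assert((align & (align - 1)) == 0);  // power-of-two alignment
    const size_t misalignment = byte_offset & (align - 1);
    return { byte_offset & ~(align - 1), misalignment / type_size };
}

int main() {
    // e.g. a float tensor starting 520 bytes into its buffer, 256-byte alignment:
    const split_offset s = split(/*byte_offset=*/520, /*align=*/256, /*type_size=*/sizeof(float));
    printf("base=%zu elements=%zu\n", s.aligned_base, s.elem_offset);  // base=512 elements=2
    return 0;
}
```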
static void ggml_webgpu_mul_mat(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) {
std::vector<uint32_t> params = {
(uint32_t) dst->ne[1], // number of rows in result (M)
@ -487,6 +621,11 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
ggml_webgpu_cpy(ctx, src0, node);
break;
}
case GGML_OP_SET_ROWS:
{
ggml_webgpu_set_rows(ctx, src0, src1, node);
break;
}
case GGML_OP_MUL_MAT:
{
ggml_webgpu_mul_mat(ctx, src0, src1, node);
@ -771,6 +910,14 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline, wgsl_mul_mat, "mul_mat");
}
static void ggml_webgpu_init_set_rows_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants(1);
constants[0].key = "wg_size";
constants[0].value = webgpu_ctx->limits.maxComputeWorkgroupSizeX;
ggml_webgpu_create_pipeline(
webgpu_ctx->device, webgpu_ctx->set_rows_pipeline, wgsl_set_rows, "set_rows", constants);
}
static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants(1);
constants[0].key = "wg_size";
@ -827,11 +974,35 @@ static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, co
webgpu_ctx->queue = webgpu_ctx->device.GetQueue();
// Create buffer pool for shader parameters
webgpu_ctx->param_buf_pool.init(webgpu_ctx->device);
webgpu_ctx->param_buf_pool.init(webgpu_ctx->device,
WEBGPU_NUM_PARAM_BUFS,
WEBGPU_PARAMS_BUF_SIZE_BYTES,
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform,
wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite);
webgpu_ctx->set_rows_error_buf_pool.init(webgpu_ctx->device,
WEBGPU_NUM_SET_ROWS_ERROR_BUFS,
WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage,
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead);
ggml_webgpu_init_memset_pipeline(webgpu_ctx);
ggml_webgpu_init_mul_mat_pipeline(webgpu_ctx);
ggml_webgpu_init_set_rows_pipeline(webgpu_ctx);
ggml_webgpu_init_cpy_pipeline(webgpu_ctx);
#ifdef GGML_WEBGPU_DEBUG
// Initialize debug buffers
ggml_webgpu_create_buffer(webgpu_ctx->device,
webgpu_ctx->debug_host_buf,
WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead,
"debug_host_buf");
ggml_webgpu_create_buffer(webgpu_ctx->device,
webgpu_ctx->debug_dev_buf,
WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),
wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc,
"debug_dev_buf");
#endif
webgpu_ctx->device_init = true;
}
@ -882,7 +1053,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
case GGML_OP_VIEW:
case GGML_OP_PERMUTE:
return true;
case GGML_OP_CPY:
case GGML_OP_CPY:
case GGML_OP_SET_ROWS:
return op->type == GGML_TYPE_F16 && op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_MUL_MAT:
return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;

View file

@ -0,0 +1,82 @@
enable f16;
@group(0) @binding(0)
var<storage, read_write> src: array<f32>;
@group(0) @binding(1)
var<storage, read_write> idx: array<u32>;
@group(0) @binding(2)
var<storage, read_write> dst: array<f16>;
@group(0) @binding(3)
var<storage, read_write> error: atomic<u32>;
struct Params {
offset_src: u32, // in elements
offset_idx: u32, // in elements
offset_dst: u32, // in elements
// Strides (in elements)
stride_src1: u32,
stride_src2: u32,
stride_src3: u32,
stride_idx0: u32,
stride_idx1: u32,
stride_idx2: u32,
stride_dst1: u32,
stride_dst2: u32,
stride_dst3: u32,
// Shape of src
ne0: u32,
n_rows: u32,
ne2: u32,
ne3: u32,
// Shape of idx
idx1: u32,
idx2: u32,
};
@group(0) @binding(4)
var<uniform> params: Params;
override wg_size: u32;
@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
if (gid.x >= params.n_rows * params.ne2 * params.ne3) {
return;
}
var i = gid.x;
let i_src3 = i / (params.ne2 * params.n_rows);
let i_dst3 = i / (params.ne2 * 3);
i = i % (params.ne2 * params.n_rows);
let i_src2 = i / params.n_rows;
let i_src1 = i % params.n_rows;
let i_idx2 = i_src3 % params.idx2;
let i_idx1 = i_src2 % params.idx1;
let i_idx0 = i_src1;
let idx_high = (params.offset_idx + i_idx0 * params.stride_idx0 + i_idx1 * params.stride_idx1 + i_idx2 * params.stride_idx2) * 2;
let idx_high_val = idx[idx_high];
let idx_low_val = idx[idx_high + 1];
if (idx_low_val != 0) {
// Upper bits of index are not zero, output will be incorrect
atomicStore(&error, 1);
return;
}
let i_dst_row = params.offset_dst + idx_high_val * params.stride_dst1 + i_src2 * params.stride_dst2 + i_src3 * params.stride_dst3;
let i_src_row = params.offset_src + i_src1 * params.stride_src1 + i_src2 * params.stride_src2 + i_src3 * params.stride_src3;
for (var i: u32 = 0; i < params.ne0; i++) {
dst[i_dst_row + i] = f16(src[i_src_row + i]);
}
}
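The shader binds the row indices as an array<u32> and reads each I64 index as two 32-bit words, taking the first word as the destination row and setting the error flag when the second word is non-zero (the backend then aborts with the "SET_ROWS index > 2^32" message seen earlier). A host-side C++ mirror of that check, assuming little-endian word order, which is what reading word 0 as the row implies:

```cpp
// Host-side mirror of the index check in the set_rows shader above: each
// int64 row index is viewed as two u32 words, and only values that fit in
// 32 bits are accepted. Little-endian word order is assumed, matching the
// shader's use of word 0 as the row and word 1 as the overflow guard.
#include <cstdint>
#include <cstring>

bool set_rows_index_ok(int64_t idx, uint32_t & row_out) {
    uint32_t words[2];
    std::memcpy(words, &idx, sizeof(words));  // words[0] = low 32 bits on little-endian
    if (words[1] != 0) {
        return false;                         // would trip the shader's error flag
    }
    row_out = words[0];
    return true;
}
```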

View file

@ -63,7 +63,7 @@ dry_seq_break_max = 128
extra_images_max = 4
# global vars
KcppVersion = "1.97"
KcppVersion = "1.97.1"
showdebug = True
kcpp_instance = None #global running instance
global_memory = {"tunnel_url": "", "restart_target":"", "input_to_exit":False, "load_complete":False, "restart_override_config_target":""}

View file

@ -0,0 +1,59 @@
{# Alias tools -> available_tools #}
{%- if tools and not available_tools -%}
{%- set available_tools = tools -%}
{%- endif -%}
{%- if messages[0]['role'] == 'system' %}
{%- set system_message = messages[0]['content'] %}
{%- set loop_messages = messages[1:] %}
{%- else %}
{%- set system_message = "Knowledge Cutoff Date: April 2024. Today's Date: " + strftime_now('%B %d, %Y') + ". You are Granite, developed by IBM." %}
{%- if available_tools and documents %}
{%- set system_message = system_message + " You are a helpful assistant with access to the following tools. When a tool is required to answer the user's query, respond only with <|tool_call|> followed by a JSON list of tools used. If a tool does not exist in the provided list of tools, notify the user that you do not have the ability to fulfill the request. Write the response to the user's input by strictly aligning with the facts in the provided documents. If the information needed to answer the question is not available in the documents, inform the user that the question cannot be answered based on the available data." %}
{%- elif available_tools %}
{%- set system_message = system_message + " You are a helpful assistant with access to the following tools. When a tool is required to answer the user's query, respond only with <|tool_call|> followed by a JSON list of tools used. If a tool does not exist in the provided list of tools, notify the user that you do not have the ability to fulfill the request." %}
{%- elif documents %}
{%- set system_message = system_message + " Write the response to the user's input by strictly aligning with the facts in the provided documents. If the information needed to answer the question is not available in the documents, inform the user that the question cannot be answered based on the available data." %}
{%- elif thinking %}
{%- set system_message = system_message + " You are a helpful AI assistant.
Respond to every user query in a comprehensive and detailed way. You can write down your thoughts and reasoning process before responding. In the thought process, engage in a comprehensive cycle of analysis, summarization, exploration, reassessment, reflection, backtracing, and iteration to develop well-considered thinking process. In the response section, based on various attempts, explorations, and reflections from the thoughts section, systematically present the final solution that you deem correct. The response should summarize the thought process. Write your thoughts between <think></think> and write your response between <response></response> for each user query." %}
{%- else %}
{%- set system_message = system_message + " You are a helpful AI assistant." %}
{%- endif %}
{%- if 'citations' in controls and documents %}
{%- set system_message = system_message + '
Use the symbols <|start_of_cite|> and <|end_of_cite|> to indicate when a fact comes from a document in the search result, e.g <|start_of_cite|> {document_id: 1}my fact <|end_of_cite|> for a fact from document 1. Afterwards, list all the citations with their corresponding documents in an ordered list.' %}
{%- endif %}
{%- if 'hallucinations' in controls and documents %}
{%- set system_message = system_message + '
Finally, after the response is written, include a numbered list of sentences from the response with a corresponding risk value that are hallucinated and not based in the documents.' %}
{%- endif %}
{%- set loop_messages = messages %}
{%- endif %}
{{- '<|start_of_role|>system<|end_of_role|>' + system_message + '<|end_of_text|>
' }}
{%- if available_tools %}
{{- '<|start_of_role|>available_tools<|end_of_role|>' }}
{{- available_tools | tojson(indent=4) }}
{{- '<|end_of_text|>
' }}
{%- endif %}
{%- if documents %}
{%- for document in documents %}
{{- '<|start_of_role|>document {"document_id": "' + document['doc_id'] | string + '"}<|end_of_role|>
' }}
{{- document['text'] }}
{{- '<|end_of_text|>
' }}
{%- endfor %}
{%- endif %}
{%- for message in loop_messages %}
{{- '<|start_of_role|>' + message['role'] + '<|end_of_role|>' + message['content'] + '<|end_of_text|>
' }}
{%- if loop.last and add_generation_prompt %}
{{- '<|start_of_role|>assistant' }}
{%- if controls %}
{{- ' ' + controls | tojson()}}
{%- endif %}
{{- '<|end_of_role|>' }}
{%- endif %}
{%- endfor %}

View file

@ -193,11 +193,11 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
return LLM_CHAT_TEMPLATE_LLAMA4;
} else if (tmpl_contains("<|endofuserprompt|>")) {
return LLM_CHAT_TEMPLATE_DOTS1;
} else if (tmpl_contains("<|startoftext|>") && tmpl_contains("<|extra_4|>")) {
} else if (tmpl_contains("<|extra_0|>") && tmpl_contains("<|extra_4|>")) {
return LLM_CHAT_TEMPLATE_HUNYUAN_MOE;
} else if (tmpl_contains("<|start|>") && tmpl_contains("<|channel|>")) {
return LLM_CHAT_TEMPLATE_OPENAI_MOE;
} else if (tmpl_contains("<hy_place▁holder▁no▁2>") && tmpl_contains("<hy_place▁holder▁no▁3>")) {
} else if (tmpl_contains("<hy_Assistant>") && tmpl_contains("<hy_place▁holder▁no▁3>")) {
return LLM_CHAT_TEMPLATE_HUNYUAN_DENSE;
} else if (tmpl_contains("<|im_assistant|>assistant<|im_middle|>")) {
return LLM_CHAT_TEMPLATE_KIMI_K2;
@ -625,8 +625,6 @@ int32_t llm_chat_apply_template(
} else if (tmpl == LLM_CHAT_TEMPLATE_YANDEX) {
// Yandex template ("\n\n" is defined as EOT token)
ss << "<s>";
for (size_t i = 0; i < chat.size(); i++) {
std::string role(chat[i]->role);
if (role == "user") {

View file

@ -1,5 +1,5 @@
-r ../../requirements/requirements-convert_legacy_llama.txt
--extra-index-url https://download.pytorch.org/whl/cpu
pillow~=11.3.0
torch~=2.2.1
torchvision~=0.17.1
torch~=2.4.0
torchvision~=0.19.1