Merge branch 'upstream' into concedo_experimental

# Conflicts:
#	.github/workflows/build.yml
#	README.md
#	docs/backend/SYCL.md
#	ggml/src/ggml-sycl/CMakeLists.txt
#	ggml/src/ggml-vulkan/CMakeLists.txt
#	ggml/src/ggml-vulkan/ggml-vulkan.cpp
#	scripts/sync-ggml.last
#	tests/test-chat-template.cpp
This commit is contained in:
Concedo 2025-04-01 20:16:07 +08:00
commit 9e182b3e78
44 changed files with 2395 additions and 831 deletions

View file

@ -708,6 +708,12 @@ class Model:
if chkhsh == "7dec86086fcc38b66b7bc1575a160ae21cf705be7718b9d5598190d7c12db76f": if chkhsh == "7dec86086fcc38b66b7bc1575a160ae21cf705be7718b9d5598190d7c12db76f":
# ref: https://huggingface.co/UW/OLMo2-8B-SuperBPE-t180k # ref: https://huggingface.co/UW/OLMo2-8B-SuperBPE-t180k
res = "superbpe" res = "superbpe"
if chkhsh == "1994ffd01900cfb37395608534236ecd63f2bd5995d6cb1004dda1af50240f15":
# ref: https://huggingface.co/trillionlabs/Trillion-7B-preview
res = "trillion"
if chkhsh == "96a5f08be6259352137b512d4157e333e21df7edd3fcd152990608735a65b224":
# ref: https://huggingface.co/inclusionAI/Ling-lite
res = "bailingmoe"
if res is None: if res is None:
logger.warning("\n") logger.warning("\n")
@ -3551,8 +3557,8 @@ class RWKV6Qwen2Model(Rwkv6Model):
head_size = hidden_size // num_attention_heads head_size = hidden_size // num_attention_heads
rms_norm_eps = self.hparams["rms_norm_eps"] rms_norm_eps = self.hparams["rms_norm_eps"]
intermediate_size = self.hparams["intermediate_size"] intermediate_size = self.hparams["intermediate_size"]
time_mix_extra_dim = 64 if hidden_size >= 4096 else 32 time_mix_extra_dim = self.hparams.get("lora_rank_tokenshift", 64 if hidden_size >= 4096 else 32)
time_decay_extra_dim = 128 if hidden_size >= 4096 else 64 time_decay_extra_dim = self.hparams.get("lora_rank_decay", 128 if hidden_size >= 4096 else 64)
# RWKV isn't context limited # RWKV isn't context limited
self.gguf_writer.add_context_length(1048576) self.gguf_writer.add_context_length(1048576)
@ -5130,6 +5136,108 @@ class GraniteMoeModel(GraniteModel):
return super().modify_tensors(data_torch, name, bid) return super().modify_tensors(data_torch, name, bid)
@Model.register("BailingMoeForCausalLM")
class BailingMoeModel(Model):
model_arch = gguf.MODEL_ARCH.BAILINGMOE
def set_vocab(self):
self._set_vocab_gpt2()
def set_gguf_parameters(self):
super().set_gguf_parameters()
hparams = self.hparams
if hparams.get("head_dim"):
rope_dim = hparams["head_dim"]
else:
rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"]
self.gguf_writer.add_rope_dimension_count(rope_dim)
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"])
self.gguf_writer.add_vocab_size(hparams["vocab_size"])
self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"])
self.gguf_writer.add_expert_weights_scale(1.0)
self.gguf_writer.add_expert_count(hparams["num_experts"])
self.gguf_writer.add_expert_shared_count(hparams["num_shared_experts"])
self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"])
_experts: list[dict[str, Tensor]] | None = None
@staticmethod
def permute(weights: Tensor, n_head: int, n_head_kv: int | None):
if n_head_kv is not None and n_head != n_head_kv:
n_head = n_head_kv
return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
.swapaxes(1, 2)
.reshape(weights.shape))
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
n_head = self.hparams["num_attention_heads"]
n_kv_head = self.hparams.get("num_key_value_heads")
n_embd = self.hparams["hidden_size"]
head_dim = self.hparams.get("head_dim", n_embd // n_head)
output_name = self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT)
if name.endswith("attention.dense.weight"):
return [(self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_OUT, bid), data_torch)]
elif name.endswith("query_key_value.weight"):
q, k, v = data_torch.split([n_head * head_dim, n_kv_head * head_dim, n_kv_head * head_dim], dim=-2)
return [
(self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q, bid), BailingMoeModel.permute(q, n_head, n_head)),
(self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K, bid), BailingMoeModel.permute(k, n_head, n_kv_head)),
(self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_V, bid), v)
]
elif name.find("mlp.experts") != -1:
n_experts = self.hparams["num_experts"]
assert bid is not None
tensors: list[tuple[str, Tensor]] = []
if self._experts is None:
self._experts = [{} for _ in range(self.block_count)]
self._experts[bid][name] = data_torch
if len(self._experts[bid]) >= n_experts * 3:
# merge the experts into a single 3d tensor
for w_name in ["down_proj", "gate_proj", "up_proj"]:
datas: list[Tensor] = []
for xid in range(n_experts):
ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
datas.append(self._experts[bid][ename])
del self._experts[bid][ename]
data_torch = torch.stack(datas, dim=0)
merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
new_name = self.map_tensor_name(merged_name)
tensors.append((new_name, data_torch))
return tensors
new_name = self.map_tensor_name(name)
if new_name == output_name and self.hparams.get("norm_head"):
data_torch = data_torch.float()
data_torch /= torch.norm(data_torch, p=2, dim=0, keepdim=True) + 1e-7
return [(new_name, data_torch)]
def prepare_tensors(self):
super().prepare_tensors()
if self._experts is not None:
# flatten `list[dict[str, Tensor]]` into `list[str]`
experts = [k for d in self._experts for k in d.keys()]
if len(experts) > 0:
raise ValueError(f"Unprocessed experts: {experts}")
@Model.register("ChameleonForConditionalGeneration") @Model.register("ChameleonForConditionalGeneration")
@Model.register("ChameleonForCausalLM") # obsolete @Model.register("ChameleonForCausalLM") # obsolete
class ChameleonModel(Model): class ChameleonModel(Model):

View file

@ -111,6 +111,8 @@ models = [
{"name": "deepseek-r1-qwen", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"}, {"name": "deepseek-r1-qwen", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"},
{"name": "gpt-4o", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Xenova/gpt-4o", }, {"name": "gpt-4o", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Xenova/gpt-4o", },
{"name": "superbpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/UW/OLMo2-8B-SuperBPE-t180k", }, {"name": "superbpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/UW/OLMo2-8B-SuperBPE-t180k", },
{"name": "trillion", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/trillionlabs/Trillion-7B-preview", },
{"name": "bailingmoe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inclusionAI/Ling-lite", },
] ]

View file

@ -1517,14 +1517,16 @@ struct clip_ctx * clip_init(const char * fname, struct clip_context_params ctx_p
const int n_kv = gguf_get_n_kv(ctx); const int n_kv = gguf_get_n_kv(ctx);
const int ftype = get_u32(ctx, KEY_FTYPE); const int ftype = get_u32(ctx, KEY_FTYPE);
const std::string ftype_str = get_ftype(ftype); const std::string ftype_str = get_ftype(ftype);
const int idx_desc = get_key_idx(ctx, KEY_DESCRIPTION);
const std::string description = gguf_get_val_str(ctx, idx_desc);
const int idx_name = gguf_find_key(ctx, KEY_NAME); const int idx_name = gguf_find_key(ctx, KEY_NAME);
if (idx_name != -1) { // make name optional temporarily as some of the uploaded models missing it due to a bug if (idx_name != -1) { // make name optional temporarily as some of the uploaded models missing it due to a bug
const std::string name = gguf_get_val_str(ctx, idx_name); const std::string name = gguf_get_val_str(ctx, idx_name);
LOG_INF("%s: model name: %s\n", __func__, name.c_str()); LOG_INF("%s: model name: %s\n", __func__, name.c_str());
} }
LOG_INF("%s: description: %s\n", __func__, description.c_str()); const int idx_desc = gguf_find_key(ctx, KEY_DESCRIPTION);
if (idx_desc != -1) { // ditto
const std::string description = gguf_get_val_str(ctx, idx_desc);
LOG_INF("%s: description: %s\n", __func__, description.c_str());
}
LOG_INF("%s: GGUF version: %d\n", __func__, gguf_get_version(ctx)); LOG_INF("%s: GGUF version: %d\n", __func__, gguf_get_version(ctx));
LOG_INF("%s: alignment: %zu\n", __func__, gguf_get_alignment(ctx)); LOG_INF("%s: alignment: %zu\n", __func__, gguf_get_alignment(ctx));
LOG_INF("%s: n_tensors: %d\n", __func__, n_tensors); LOG_INF("%s: n_tensors: %d\n", __func__, n_tensors);

View file

@ -699,11 +699,13 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
const std::string voice_data = audio_data; const std::string voice_data = audio_data;
auto tmp = common_tokenize(vocab, voice_data, false, true); auto tmp = common_tokenize(vocab, voice_data, false, true);
printf("\n\n");
std::ostringstream tokens_oss;
for (size_t i = 0; i < tmp.size(); ++i) { for (size_t i = 0; i < tmp.size(); ++i) {
printf("%d, ", tmp[i]); tokens_oss << tmp[i] << ", ";
} }
printf("\n\n"); LOG_INF("\n\n%s: llama tokens: %s\n\n", __func__, tokens_oss.str().c_str());
prompt_add(prompt_inp, tmp); prompt_add(prompt_inp, tmp);
#else #else
prompt_add(prompt_inp, llama_tokens { prompt_add(prompt_inp, llama_tokens {

View file

@ -33,6 +33,8 @@ bool g_mul_mat_q = true;
#include "ggml-cuda/rope.cuh" #include "ggml-cuda/rope.cuh"
#include "ggml-cuda/scale.cuh" #include "ggml-cuda/scale.cuh"
#include "ggml-cuda/softmax.cuh" #include "ggml-cuda/softmax.cuh"
#include "ggml-cuda/ssm-conv.cuh"
#include "ggml-cuda/ssm-scan.cuh"
#include "ggml-cuda/sum.cuh" #include "ggml-cuda/sum.cuh"
#include "ggml-cuda/sumrows.cuh" #include "ggml-cuda/sumrows.cuh"
#include "ggml-cuda/tsembd.cuh" #include "ggml-cuda/tsembd.cuh"
@ -2301,6 +2303,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_SUM_ROWS: case GGML_OP_SUM_ROWS:
ggml_cuda_op_sum_rows(ctx, dst); ggml_cuda_op_sum_rows(ctx, dst);
break; break;
case GGML_OP_SSM_CONV:
ggml_cuda_op_ssm_conv(ctx, dst);
break;
case GGML_OP_SSM_SCAN:
ggml_cuda_op_ssm_scan(ctx, dst);
break;
case GGML_OP_ARGSORT: case GGML_OP_ARGSORT:
ggml_cuda_op_argsort(ctx, dst); ggml_cuda_op_argsort(ctx, dst);
break; break;
@ -3198,6 +3206,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_COS: case GGML_OP_COS:
case GGML_OP_CLAMP: case GGML_OP_CLAMP:
case GGML_OP_LOG: case GGML_OP_LOG:
case GGML_OP_SSM_SCAN:
case GGML_OP_SSM_CONV:
return true; return true;
case GGML_OP_CONT: case GGML_OP_CONT:
return op->src[0]->type != GGML_TYPE_BF16; return op->src[0]->type != GGML_TYPE_BF16;

View file

@ -0,0 +1,151 @@
#include "ssm-conv.cuh"
template <size_t split_d_inner, size_t d_conv>
static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float * __restrict__ src1,
const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src1_nb1,
float * __restrict__ dst, const int dst_nb0, const int dst_nb1, const int dst_nb2,
const int nc, const int ncs, const int nr, const int n_t, const int n_s) {
const int tid = threadIdx.x;
const int bidx = blockIdx.x;
const int bidy = blockIdx.y;
const float * x_block = (const float *) ((char *) src0 + bidx * src0_nb2 + bidy * split_d_inner * src0_nb1);
const float * w_block = (const float *) ((char *) src1 + bidy * split_d_inner * src1_nb1);
float * y_block = (float *) ((char *) dst + bidx * dst_nb2 + bidy * split_d_inner * dst_nb0);
const int stride_x = src0_nb1 / sizeof(float);
const int stride_w = src1_nb1 / sizeof(float);
const int stride_y = dst_nb1 / sizeof(float);
float x[d_conv] = { 0.0f };
float w[d_conv] = { 0.0f };
#pragma unroll
for (int j = 0; j < d_conv; j++) {
w[j] = w_block[tid * stride_w + j];
}
for (int i = 0; i < n_t; i++) {
float sumf = 0.0f;
if (i == 0) {
for (int j = 0; j < d_conv; j++) {
x[j] = x_block[tid * stride_x + j];
}
} else {
x[(i - 1) % d_conv] = x_block[tid * stride_x + i + d_conv - 1];
}
#pragma unroll
for (int j = 0; j < d_conv; j++) {
sumf += x[(i + j) % d_conv] * w[j];
}
y_block[i * stride_y + tid] = sumf;
}
}
template <size_t split_d_inner, size_t d_conv, size_t split_n_t>
static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0, const float * __restrict__ src1,
const int src0_nb0, const int src0_nb1, const int src0_nb2,
const int src1_nb1, float * __restrict__ dst, const int dst_nb0,
const int dst_nb1, const int dst_nb2, const int nc, const int ncs,
const int nr, const int n_t, const int n_s) {
const int tid = threadIdx.x;
const int bidx = blockIdx.x;
const int bidy = blockIdx.y;
const int bidz = blockIdx.z;
const float * x_block = (const float *) ((char *) src0 + bidx * src0_nb2 + bidy * split_d_inner * src0_nb1 +
bidz * split_n_t * src0_nb0);
const float * w_block = (const float *) ((char *) src1 + bidy * split_d_inner * src1_nb1);
float * y_block =
(float *) ((char *) dst + bidx * dst_nb2 + bidz * split_n_t * dst_nb1 + bidy * split_d_inner * dst_nb0);
const int stride_x = src0_nb1 / sizeof(float);
const int stride_w = src1_nb1 / sizeof(float);
const int stride_y = dst_nb1 / sizeof(float);
float x[d_conv] = { 0.0f };
float w[d_conv] = { 0.0f };
#pragma unroll
for (int j = 0; j < d_conv; j++) {
w[j] = w_block[tid * stride_w + j];
}
#pragma unroll
for (int i = 0; i < split_n_t; i++) {
if (bidz * split_n_t + i < n_t) {
float sumf = 0.0f;
if (i == 0) {
for (int j = 0; j < d_conv; j++) {
x[j] = x_block[tid * stride_x + j];
}
} else {
x[(i - 1) % d_conv] = x_block[tid * stride_x + i + d_conv - 1];
}
#pragma unroll
for (int j = 0; j < d_conv; j++) {
sumf += x[(i + j) % d_conv] * w[j];
}
y_block[i * stride_y + tid] = sumf;
}
}
}
static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int src0_nb0, const int src0_nb1,
const int src0_nb2, const int src1_nb1, float * dst, const int dst_nb0, const int dst_nb1,
const int dst_nb2, const int nc, const int ncs, const int nr, const int n_t,
const int n_s, cudaStream_t stream) {
const int threads = 128;
GGML_ASSERT(nr % threads == 0);
if (n_t <= 32) {
const dim3 blocks(n_s, (nr + threads - 1) / threads, 1);
if (nc == 4) {
ssm_conv_f32<threads, 4><<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
dst, dst_nb0, dst_nb1, dst_nb2, nc, ncs, nr, n_t,
n_s);
} else {
GGML_ABORT("Only support kernel size = 4 now.");
}
} else {
if (nc == 4) {
const int split_n_t = 32;
dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t);
ssm_conv_long_token_f32<threads, 4, split_n_t>
<<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0,
dst_nb1, dst_nb2, nc, ncs, nr, n_t, n_s);
} else {
GGML_ABORT("Only support kernel size = 4 right now.");
}
}
}
void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0]; // conv_x
const struct ggml_tensor * src1 = dst->src[1]; // conv1d.weight
const int nc = src1->ne[0]; // d_conv
const int ncs = src0->ne[0]; // d_conv - 1 + n_t
const int nr = src0->ne[1]; // d_inner
const int n_t = dst->ne[1]; // tokens per sequence
const int n_s = dst->ne[2]; // number of sequences in the batch
GGML_ASSERT(dst->ne[0] == nr);
GGML_ASSERT(src0->nb[0] == sizeof(float));
GGML_ASSERT(src1->nb[0] == sizeof(float));
GGML_ASSERT(src0->nb[1] == src0->ne[0] * sizeof(float));
const float * src0_d = (const float *) src0->data;
const float * src1_d = (const float *) src1->data;
float * dst_d = (float *) dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
ssm_conv_f32_cuda(src0_d, src1_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, dst->nb[0], dst->nb[1],
dst->nb[2], nc, ncs, nr, n_t, n_s, stream);
}

View file

@ -0,0 +1,3 @@
#include "common.cuh"
void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View file

@ -0,0 +1,155 @@
#include "ssm-scan.cuh"
// #include <cuda_runtime.h>
// static __device__ void global_to_shared(const float *src, float *dst) {
// asm volatile("cp.async.");
// }
template <size_t splitD, size_t N>
__global__ void __launch_bounds__(splitD, 2)
ssm_scan_f32(const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2,
const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5,
const int src0_nb1, const int src0_nb2, const int src1_nb0, const int src1_nb1, const int src1_nb2,
const int src1_nb3, const int src2_nb0, const int src2_nb1, const int src2_nb2, const int src3_nb1,
const int src4_nb1, const int src4_nb2, const int src5_nb1, const int src5_nb2,
float * __restrict__ dst, const int D, const int L, const int B) {
const int bidx = blockIdx.x; // split along B
const int bidy = blockIdx.y; // split along D
const int tid = threadIdx.x;
const int wid = tid / 32;
const int wtid = tid % 32;
extern __shared__ float smem[];
const int stride_sA = N + 1;
const int stride_ss0 = N + 1;
float * smem_A = smem;
float * smem_s0 = smem_A + splitD * stride_sA;
const float * s0_block = (const float *) ((char *) src0 + bidx * src0_nb2 + bidy * splitD * src0_nb1);
const float * x_block = (const float *) ((char *) src1 + (bidx * src1_nb2) + bidy * splitD * sizeof(float));
const float * dt_block = (const float *) ((char *) src2 + (bidx * src2_nb2) + bidy * splitD * sizeof(float));
const float * A_block = (const float *) ((char *) src3 + bidy * splitD * src3_nb1);
const float * B_block = (const float *) ((char *) src4 + (bidx * src4_nb2));
const float * C_block = (const float *) ((char *) src5 + (bidx * src5_nb2));
float * y_block = (float *) ((char *) dst + (bidx * src1_nb2) + bidy * splitD * sizeof(float));
float * s_block = (float *) ((char *) dst + src1_nb3 + bidx * src0_nb2 + bidy * splitD * src0_nb1);
const int stride_s0 = src0_nb1 / sizeof(float);
const int stride_x = src1_nb1 / sizeof(float);
const int stride_dt = src2_nb1 / sizeof(float);
const int stride_A = src3_nb1 / sizeof(float);
const int stride_B = src4_nb1 / sizeof(float);
const int stride_C = src5_nb1 / sizeof(float);
const int stride_s = stride_s0;
const int stride_y = stride_x;
// can N not be 16? for example 32?
if (N == 16) {
#pragma unroll
for (int i = 0; i < splitD / 4; i += 2) {
float value = A_block[(wid * warpSize + i) * stride_A + wtid];
// todo: bank conflict
// I am always confused with how to use the swizzling method to solve
// bank conflit. Hoping somebody can tell me.
smem_A[(wid * warpSize + i) * stride_sA + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
}
#pragma unroll
for (int i = 0; i < splitD / 4; i += 2) {
float value = s0_block[(wid * warpSize + i) * stride_s0 + wtid];
smem_s0[(wid * warpSize + i) * stride_ss0 + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
}
}
__syncthreads();
for (int i = 0; i < L; i++) {
float dt_soft_plus = dt_block[i * stride_dt + tid];
if (dt_soft_plus <= 20.0f) {
dt_soft_plus = log1pf(exp(dt_soft_plus));
}
float x_dt = x_block[i * stride_x + tid] * dt_soft_plus;
float sumf = 0.0f;
#pragma unroll
for (int j = 0; j < N; j++) {
float state = (smem_s0[tid * stride_ss0 + j] * expf(dt_soft_plus * smem_A[tid * stride_sA + j])) +
(B_block[i * stride_B + j] * x_dt);
sumf += state * C_block[i * stride_C + j];
if (i == L - 1) {
s_block[tid * stride_s + j] = state;
} else {
smem_s0[tid * stride_ss0 + j] = state;
}
}
__syncthreads();
y_block[i * stride_y + tid] = sumf;
}
}
static void ssm_scan_f32_cuda(const float * src0, const float * src1, const float * src2, const float * src3,
const float * src4, const float * src5, const int src0_nb1, const int src0_nb2,
const int src1_nb0, const int src1_nb1, const int src1_nb2, const int src1_nb3,
const int src2_nb0, const int src2_nb1, const int src2_nb2, const int src3_nb1,
const int src4_nb1, const int src4_nb2, const int src5_nb1, const int src5_nb2,
float * dst, const int N, const int D, const int L, const int B, cudaStream_t stream) {
const int threads = 128;
// todo: consider D cannot be divided,does this situation exist?
GGML_ASSERT(D % threads == 0);
const dim3 blocks(B, (D + threads - 1) / threads, 1);
const int smem_size = (threads * (N + 1) * 2) * sizeof(float);
if (N == 16) {
ssm_scan_f32<128, 16><<<blocks, threads, smem_size, stream>>>(
src0, src1, src2, src3, src4, src5, src0_nb1, src0_nb2, src1_nb0, src1_nb1, src1_nb2, src1_nb3, src2_nb0,
src2_nb1, src2_nb2, src3_nb1, src4_nb1, src4_nb2, src5_nb1, src5_nb2, dst, D, L, B);
} else {
GGML_ABORT("doesn't support N!=16.");
}
}
void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0]; // s
const struct ggml_tensor * src1 = dst->src[1]; // x
const struct ggml_tensor * src2 = dst->src[2]; // dt
const struct ggml_tensor * src3 = dst->src[3]; // A
const struct ggml_tensor * src4 = dst->src[4]; // B
const struct ggml_tensor * src5 = dst->src[5]; // C
// const int64_t d_state = src0->ne[0];
// const int64_t d_inner = src0->ne[1];
// const int64_t l = src1->ne[1];
// const int64_t b = src0->ne[2];
const int64_t nc = src0->ne[0]; // d_state
const int64_t nr = src0->ne[1]; // d_inner
const int64_t n_t = src1->ne[1]; // number of tokens per sequence
const int64_t n_s = src0->ne[2]; // number of sequences in the batch
GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
GGML_ASSERT(src0->nb[0] == sizeof(float));
GGML_ASSERT(src1->nb[0] == sizeof(float));
GGML_ASSERT(src2->nb[0] == sizeof(float));
GGML_ASSERT(src3->nb[0] == sizeof(float));
GGML_ASSERT(src4->nb[0] == sizeof(float));
GGML_ASSERT(src5->nb[0] == sizeof(float));
// required for the dot product between s and C
GGML_ASSERT(src0->nb[1] == src0->ne[0] * sizeof(float));
// required for per-sequence offsets for states
GGML_ASSERT(src0->nb[2] == src0->ne[0] * src0->ne[1] * sizeof(float));
// required to get correct offset for state destination (i.e. src1->nb[3])
GGML_ASSERT(src1->nb[3] == src1->ne[0] * src1->ne[1] * src1->ne[2] * sizeof(float));
const float * src0_d = (const float *) src0->data;
const float * src1_d = (const float *) src1->data;
const float * src2_d = (const float *) src2->data;
const float * src3_d = (const float *) src3->data;
const float * src4_d = (const float *) src4->data;
const float * src5_d = (const float *) src5->data;
float * dst_d = (float *) dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
ssm_scan_f32_cuda(src0_d, src1_d, src2_d, src3_d, src4_d, src5_d, src0->nb[1], src0->nb[2], src1->nb[0],
src1->nb[1], src1->nb[2], src1->nb[3], src2->nb[0], src2->nb[1], src2->nb[2], src3->nb[1],
src4->nb[1], src4->nb[2], src5->nb[1], src5->nb[2], dst_d, nc, nr, n_t, n_s, stream);
}

View file

@ -0,0 +1,3 @@
#include "common.cuh"
void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View file

@ -3128,14 +3128,15 @@ kernel void kernel_flash_attn_ext(
const int iq2 = tgpig[1]; const int iq2 = tgpig[1];
const int iq1 = tgpig[0]*Q; const int iq1 = tgpig[0]*Q;
const short DK4 = DK/4; constexpr short DK4 = DK/4;
const short DK8 = DK/8; constexpr short DK8 = DK/8;
const short DK16 = DK/16; constexpr short DK16 = DK/16;
const short DV4 = DV/4; constexpr short DV4 = DV/4;
const short DV8 = DV/8; constexpr short DV8 = DV/8;
const short DV16 = DV/16; constexpr short DV16 = DV/16;
const short NW = N_SIMDWIDTH;
const short SH = (2*C + Q); // shared memory per simdgroup (s_t == float) constexpr short NW = N_SIMDWIDTH;
constexpr short SH = (2*C + Q); // shared memory per simdgroup (s_t == float)
const short TS = nsg*SH; // shared memory size per query in (s_t == float) const short TS = nsg*SH; // shared memory size per query in (s_t == float)
const short T = DK + 2*TS; // shared memory size per query in (half) const short T = DK + 2*TS; // shared memory size per query in (half)
@ -3641,11 +3642,11 @@ kernel void kernel_flash_attn_ext_vec(
const int iq2 = tgpig[1]; const int iq2 = tgpig[1];
const int iq1 = tgpig[0]; const int iq1 = tgpig[0];
const short DK4 = DK/4; constexpr short DK4 = DK/4;
const short DV4 = DV/4; constexpr short DV4 = DV/4;
const short NW = N_SIMDWIDTH; constexpr short NW = N_SIMDWIDTH;
const short NL = NW/NE; // note: this can be adjusted to support different head sizes and simdgroup work loads constexpr short NL = NW/NE; // note: this can be adjusted to support different head sizes and simdgroup work loads
const short SH = 2*C; // shared memory per simdgroup constexpr short SH = 2*C; // shared memory per simdgroup
const short T = DK + nsg*SH; // shared memory size per query in (half) const short T = DK + nsg*SH; // shared memory size per query in (half)
@ -3956,7 +3957,7 @@ kernel void kernel_flash_attn_ext_vec(
half, half4, \ half, half4, \
half4 half4
typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 128>) flash_attn_ext_vec_t; typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>) flash_attn_ext_vec_t;
template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>; template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>;
#if defined(GGML_METAL_USE_BF16) #if defined(GGML_METAL_USE_BF16)

View file

@ -66,41 +66,6 @@ int64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block
return sycl_down_blk_size; return sycl_down_blk_size;
} }
void ggml_sycl_op_flatten(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
const ggml_tensor *src1, ggml_tensor *dst,
const ggml_sycl_op_flatten_t op) try {
const bool use_src1 = src1 != nullptr;
if(use_src1)
GGML_ASSERT(strcmp(src1->buffer->buft->iface.get_name(src1->buffer->buft), GGML_SYCL_NAME "_Split") != 0);
GGML_ASSERT(strcmp(dst->buffer->buft->iface.get_name(dst->buffer->buft), GGML_SYCL_NAME "_Split") != 0);
// dd = data device
float * src0_ddf = (float *) src0->data;
float * src1_ddf = use_src1 ? (float *) src1->data : nullptr;
float * dst_ddf = (float *) dst->data;
ggml_sycl_pool_alloc<float> src0_f(ctx.pool());
ggml_sycl_pool_alloc<float> src1_f(ctx.pool());
ggml_sycl_pool_alloc<float> dst_f(ctx.pool());
ggml_sycl_set_device(ctx.device);
queue_ptr main_stream = ctx.stream();
// GGML_SYCL_DEBUG("ctx.device=%d, main_stream=%p src0_on_device=%d, src1_on_device=%d, dst_on_device=%d\n",
// ctx.device, main_stream, src0_on_device, src1_on_device, dst_on_device);
// do the computation
op(ctx, src0, src1, dst, src0_ddf, src1_ddf, dst_ddf, main_stream);
// print_ggml_tensor("tensor", dst);
}
catch (sycl::exception const &exc) {
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
<< ", line:" << __LINE__ << std::endl;
std::exit(1);
}
void release_extra_gpu(ggml_tensor_extra_gpu * extra, std::vector<queue_ptr> streams) { void release_extra_gpu(ggml_tensor_extra_gpu * extra, std::vector<queue_ptr> streams) {
for (int i = 0; i < ggml_sycl_info().device_count; ++i) { for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
for (int64_t is = 0; is < GGML_SYCL_MAX_STREAMS; ++is) { for (int64_t is = 0; is < GGML_SYCL_MAX_STREAMS; ++is) {

View file

@ -494,12 +494,6 @@ static __dpct_inline__ Tp* get_pointer(sycl::local_accessor<Tp, dim> acc) {
int64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block_size); int64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block_size);
typedef void (*ggml_sycl_op_flatten_t)(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
const ggml_tensor *src1,
ggml_tensor *dst, const float *src0_dd,
const float *src1_dd, float *dst_dd,
const queue_ptr &main_stream);
template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t> template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
static void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst, static void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
int ne0, int ne1, int ne2, int ne3, int ne0, int ne1, int ne2, int ne3,
@ -757,24 +751,22 @@ struct bin_bcast_sycl {
template <class op> template <class op>
inline void ggml_sycl_op_bin_bcast(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, inline void ggml_sycl_op_bin_bcast(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
const ggml_tensor *src1, ggml_tensor *dst, const ggml_tensor *src1, ggml_tensor *dst) {
const float *src0_dd, const float *src1_dd, dpct::queue_ptr main_stream = ctx.stream();
float *dst_dd,
const queue_ptr &main_stream) {
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
op()(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream); op()(ctx, src0, src1, dst, (const float *)src0->data, (const float *)src1->data, (float *)dst->data, main_stream);
} else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
op()(ctx, src0, src1, dst, (const sycl::half *)src0_dd, src1_dd, op()(ctx, src0, src1, dst, (const sycl::half *)src0->data, (const float *)src1->data,
(sycl::half *)dst_dd, main_stream); (sycl::half *)dst->data, main_stream);
} else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) { } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
op()(ctx, src0, src1, dst, (const sycl::half *)src0_dd, src1_dd, dst_dd, op()(ctx, src0, src1, dst, (const sycl::half *)src0->data, (const float *)src1->data, (float *)dst->data,
main_stream); main_stream);
} else if (src0->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) { } else if (src0->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) {
op()(ctx, src0, src1, dst, (const int32_t *)src0_dd, (const int32_t *)src1_dd, (int32_t *)dst_dd, op()(ctx, src0, src1, dst, (const int32_t *)src0->data, (const int32_t *)src1->data, (int32_t *)dst->data,
main_stream); main_stream);
} else if (src0->type == GGML_TYPE_I16 && dst->type == GGML_TYPE_I16) { } else if (src0->type == GGML_TYPE_I16 && dst->type == GGML_TYPE_I16) {
op()(ctx, src0, src1, dst, (const int16_t *)src0_dd, (const int16_t *)src1_dd, (int16_t *)dst_dd, op()(ctx, src0, src1, dst, (const int16_t *)src0->data, (const int16_t *)src1->data, (int16_t *)dst->data,
main_stream); main_stream);
} else { } else {
fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__, fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__,
@ -784,8 +776,4 @@ inline void ggml_sycl_op_bin_bcast(ggml_backend_sycl_context & ctx, const ggml_t
} }
bool gpu_has_xmx(sycl::device &dev); bool gpu_has_xmx(sycl::device &dev);
void ggml_sycl_op_flatten(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
const ggml_tensor *src1, ggml_tensor *dst,
const ggml_sycl_op_flatten_t op);
#endif // GGML_SYCL_COMMON_HPP #endif // GGML_SYCL_COMMON_HPP

View file

@ -16,9 +16,18 @@
#include <sycl/sycl.hpp> #include <sycl/sycl.hpp>
#include <sycl/half_type.hpp> #include <sycl/half_type.hpp>
#include <syclcompat/math.hpp> #include <syclcompat/math.hpp>
#include <oneapi/mkl.hpp>
#include <map> #include <map>
#ifdef GGML_SYCL_USE_INTEL_ONEMKL
#include <oneapi/mkl.hpp>
// Allow to use the same namespace for Intel oneMKL and oneMath
namespace oneapi {
namespace math = mkl;
}
#else
#include <oneapi/math.hpp>
#endif
#include "ggml.h" #include "ggml.h"
#if defined(__linux__) #if defined(__linux__)
@ -83,13 +92,32 @@ inline std::string get_device_backend_and_type(const sycl::device &device) {
} }
template <typename Ts> struct matrix_info_t { template <typename Ts> struct matrix_info_t {
oneapi::mkl::transpose transpose_info[2]; oneapi::math::transpose transpose_info[2];
Ts value_info[2]; Ts value_info[2];
std::int64_t size_info[3]; std::int64_t size_info[3];
std::int64_t ld_info[3]; std::int64_t ld_info[3];
std::int64_t groupsize_info; std::int64_t groupsize_info;
}; };
inline auto get_onemath_backend(sycl::queue& queue)
#if defined(GGML_SYCL_GENERIC) || defined(GGML_SYCL_USE_INTEL_ONEMKL)
-> sycl::queue&
#endif
{
// If the backend is known at compile-time, use oneMath backend_selector to use
// compile-time dispatching and avoid the need to dlopen libraries. Otherwise
// fallback to runtime dispatching.
#if defined(GGML_SYCL_NVIDIA)
return oneapi::math::backend_selector<oneapi::math::backend::cublas>{ queue };
#elif defined(GGML_SYCL_AMD)
return oneapi::math::backend_selector<oneapi::math::backend::rocblas>{ queue };
#elif defined(GGML_SYCL_GENERIC) || defined(GGML_SYCL_USE_INTEL_ONEMKL)
return queue;
#else
static_assert(false, "Unsupported backend");
#endif
}
namespace dpct namespace dpct
{ {
typedef sycl::queue *queue_ptr; typedef sycl::queue *queue_ptr;
@ -1686,26 +1714,18 @@ namespace dpct
namespace detail namespace detail
{ {
template <class Ta, class Tb, class Tc, class Ts> template <class Ta, class Tb, class Tc, class Ts>
inline void gemm_impl(sycl::queue &q, oneapi::mkl::transpose a_trans, inline void gemm_impl(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m,
oneapi::mkl::transpose b_trans, int m, int n, int k, int n, int k, const void * alpha, const void * a, int lda, const void * b, int ldb,
const void *alpha, const void *a, int lda, const void *b, const void * beta, void * c, int ldc) {
int ldb, const void *beta, void *c, int ldc) Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);
{ Ts beta_value = dpct::get_value(reinterpret_cast<const Ts *>(beta), q);
Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q); auto data_a = get_memory<const Ta>(a);
Ts beta_value = dpct::get_value(reinterpret_cast<const Ts *>(beta), q); auto data_b = get_memory<const Tb>(b);
auto data_a = get_memory<const Ta>(a); auto data_c = get_memory<Tc>(c);
auto data_b = get_memory<const Tb>(b); oneapi::math::blas::column_major::gemm(get_onemath_backend(q), a_trans, b_trans, m, n, k, alpha_value, data_a,
auto data_c = get_memory<Tc>(c); lda, data_b, ldb, beta_value, data_c, ldc);
#ifdef GGML_SYCL_NVIDIA }
oneapi::mkl::blas::column_major::gemm(oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q },
a_trans, b_trans, m, n, k, alpha_value, data_a, lda, data_b, ldb,
beta_value, data_c, ldc);
#else
oneapi::mkl::blas::column_major::gemm(q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda, data_b, ldb,
beta_value, data_c, ldc);
#endif
}
template <typename VecT, class BinaryOperation, class = void> template <typename VecT, class BinaryOperation, class = void>
class vectorized_binary class vectorized_binary
@ -1735,7 +1755,7 @@ namespace dpct
}; };
template <class Ta, class Tb, class Tc, class Ts> template <class Ta, class Tb, class Tc, class Ts>
inline void gemm_batch_impl(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, inline void gemm_batch_impl(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans,
int m, int n, int k, const void * alpha, const void ** a, int lda, const void ** b, int m, int n, int k, const void * alpha, const void ** a, int lda, const void ** b,
int ldb, const void * beta, void ** c, int ldc, int batch_size, int ldb, const void * beta, void ** c, int ldc, int batch_size,
matrix_info_t<float> * matrix_info) { matrix_info_t<float> * matrix_info) {
@ -1754,48 +1774,28 @@ namespace dpct
matrix_info->ld_info[2] = ldc; matrix_info->ld_info[2] = ldc;
matrix_info->groupsize_info = batch_size; matrix_info->groupsize_info = batch_size;
#ifdef GGML_SYCL_NVIDIA sycl::event e = oneapi::math::blas::column_major::gemm_batch(
sycl::event e = oneapi::mkl::blas::column_major::gemm_batch( get_onemath_backend(q), matrix_info->transpose_info, matrix_info->transpose_info + 1,
oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q }, matrix_info->transpose_info, matrix_info->size_info, matrix_info->size_info + 1, matrix_info->size_info + 2,
matrix_info->transpose_info + 1, matrix_info->size_info, matrix_info->size_info + 1, reinterpret_cast<Ts *>(matrix_info->value_info), reinterpret_cast<const Ta **>(a), matrix_info->ld_info,
matrix_info->size_info + 2, reinterpret_cast<Ts *>(matrix_info->value_info), reinterpret_cast<const Tb **>(b), matrix_info->ld_info + 1,
reinterpret_cast<const Ta **>(a), matrix_info->ld_info, reinterpret_cast<const Tb **>(b), reinterpret_cast<Ts *>(matrix_info->value_info + 1), reinterpret_cast<Tc **>(c),
matrix_info->ld_info + 1, reinterpret_cast<Ts *>(matrix_info->value_info + 1), matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
reinterpret_cast<Tc **>(c), matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
#else
sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
q, matrix_info->transpose_info, matrix_info->transpose_info + 1, matrix_info->size_info,
matrix_info->size_info + 1, matrix_info->size_info + 2, reinterpret_cast<Ts *>(matrix_info->value_info),
reinterpret_cast<const Ta **>(a), matrix_info->ld_info, reinterpret_cast<const Tb **>(b),
matrix_info->ld_info + 1, reinterpret_cast<Ts *>(matrix_info->value_info + 1),
reinterpret_cast<Tc **>(c), matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
#endif
} }
template <class Ta, class Tb, class Tc, class Ts> template <class Ta, class Tb, class Tc, class Ts>
inline void inline void gemm_batch_impl(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans,
gemm_batch_impl(sycl::queue &q, oneapi::mkl::transpose a_trans, int m, int n, int k, const void * alpha, const void * a, int lda,
oneapi::mkl::transpose b_trans, int m, int n, long long int stride_a, const void * b, int ldb, long long int stride_b,
int k, const void *alpha, const void *a, int lda, const void * beta, void * c, int ldc, long long int stride_c, int batch_size) {
long long int stride_a, const void *b, int ldb,
long long int stride_b, const void *beta, void *c,
int ldc, long long int stride_c, int batch_size)
{
Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q); Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);
Ts beta_value = dpct::get_value(reinterpret_cast<const Ts *>(beta), q); Ts beta_value = dpct::get_value(reinterpret_cast<const Ts *>(beta), q);
auto data_a = get_memory<const Ta>(a); auto data_a = get_memory<const Ta>(a);
auto data_b = get_memory<const Tb>(b); auto data_b = get_memory<const Tb>(b);
auto data_c = get_memory<Tc>(c); auto data_c = get_memory<Tc>(c);
#ifdef GGML_SYCL_NVIDIA oneapi::math::blas::column_major::gemm_batch(get_onemath_backend(q), a_trans, b_trans, m, n, k, alpha_value,
oneapi::mkl::blas::column_major::gemm_batch( data_a, lda, stride_a, data_b, ldb, stride_b, beta_value,
oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q }, a_trans, b_trans, m, n, k, data_c, ldc, stride_c, batch_size);
alpha_value, data_a, lda, stride_a, data_b, ldb, stride_b, beta_value, data_c, ldc, stride_c,
batch_size);
#else
oneapi::mkl::blas::column_major::gemm_batch(q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda,
stride_a, data_b, ldb, stride_b, beta_value, data_c, ldc,
stride_c, batch_size);
#endif
} }
} // namespace detail } // namespace detail
@ -2259,13 +2259,10 @@ namespace dpct
sycl::range<3>(x, y, 1), direction); sycl::range<3>(x, y, 1), direction);
} }
inline void gemm(sycl::queue &q, oneapi::mkl::transpose a_trans, inline void gemm(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m, int n,
oneapi::mkl::transpose b_trans, int m, int n, int k, int k, const void * alpha, const void * a, library_data_t a_type, int lda, const void * b,
const void *alpha, const void *a, library_data_t a_type, library_data_t b_type, int ldb, const void * beta, void * c, library_data_t c_type, int ldc,
int lda, const void *b, library_data_t b_type, int ldb, library_data_t scaling_type) {
const void *beta, void *c, library_data_t c_type, int ldc,
library_data_t scaling_type)
{
if (scaling_type == library_data_t::real_float && if (scaling_type == library_data_t::real_float &&
c_type == library_data_t::complex_float) c_type == library_data_t::complex_float)
{ {
@ -2329,9 +2326,8 @@ namespace dpct
library_data_t::real_bfloat16, library_data_t::real_bfloat16, library_data_t::real_bfloat16, library_data_t::real_bfloat16,
library_data_t::real_float, library_data_t::real_float): library_data_t::real_float, library_data_t::real_float):
{ {
detail::gemm_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float, detail::gemm_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, float, float>(
float>(q, a_trans, b_trans, m, n, k, alpha, a, lda, b, q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
ldb, beta, c, ldc);
break; break;
} }
case detail::get_type_combination_id( case detail::get_type_combination_id(
@ -2369,8 +2365,7 @@ namespace dpct
library_data_t::real_bfloat16, library_data_t::real_bfloat16, library_data_t::real_bfloat16, library_data_t::real_bfloat16,
library_data_t::real_bfloat16, library_data_t::real_float): library_data_t::real_bfloat16, library_data_t::real_float):
{ {
detail::gemm_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, detail::gemm_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, oneapi::math::bfloat16, float>(
oneapi::mkl::bfloat16, float>(
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
break; break;
} }
@ -2390,7 +2385,7 @@ namespace dpct
default: default:
throw std::runtime_error("the combination of data type is unsupported"); throw std::runtime_error("the combination of data type is unsupported");
} }
} // gemm() } // gemm()
/// Computes a batch of matrix-matrix product with general matrices. /// Computes a batch of matrix-matrix product with general matrices.
/// \param [in] q The queue where the routine should be executed. /// \param [in] q The queue where the routine should be executed.
@ -2412,7 +2407,7 @@ namespace dpct
/// \param [in] ldc Leading dimension of C. /// \param [in] ldc Leading dimension of C.
/// \param [in] batch_size Specifies the number of matrix multiply operations to perform. /// \param [in] batch_size Specifies the number of matrix multiply operations to perform.
/// \param [in] scaling_type Data type of the scaling factors. /// \param [in] scaling_type Data type of the scaling factors.
inline void gemm_batch(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m, inline void gemm_batch(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m,
int n, int k, const void * alpha, const void * a[], library_data_t a_type, int lda, int n, int k, const void * alpha, const void * a[], library_data_t a_type, int lda,
const void * b[], library_data_t b_type, int ldb, const void * beta, void * c[], const void * b[], library_data_t b_type, int ldb, const void * beta, void * c[],
library_data_t c_type, int ldc, int batch_size, library_data_t scaling_type, library_data_t c_type, int ldc, int batch_size, library_data_t scaling_type,
@ -2450,7 +2445,7 @@ namespace dpct
library_data_t::real_bfloat16, library_data_t::real_bfloat16, library_data_t::real_bfloat16, library_data_t::real_bfloat16,
library_data_t::real_bfloat16, library_data_t::real_float): library_data_t::real_bfloat16, library_data_t::real_float):
{ {
detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float>( detail::gemm_batch_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, oneapi::math::bfloat16, float>(
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info); q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
break; break;
} }
@ -2458,7 +2453,7 @@ namespace dpct
library_data_t::real_bfloat16, library_data_t::real_bfloat16, library_data_t::real_bfloat16, library_data_t::real_bfloat16,
library_data_t::real_float, library_data_t::real_float): library_data_t::real_float, library_data_t::real_float):
{ {
detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float, float>( detail::gemm_batch_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, float, float>(
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info); q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
break; break;
} }
@ -2534,15 +2529,11 @@ namespace dpct
/// \param [in] stride_c Stride between the different C matrices. /// \param [in] stride_c Stride between the different C matrices.
/// \param [in] batch_size Specifies the number of matrix multiply operations to perform. /// \param [in] batch_size Specifies the number of matrix multiply operations to perform.
/// \param [in] scaling_type Data type of the scaling factors. /// \param [in] scaling_type Data type of the scaling factors.
inline void gemm_batch(sycl::queue &q, oneapi::mkl::transpose a_trans, inline void gemm_batch(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m,
oneapi::mkl::transpose b_trans, int m, int n, int k, int n, int k, const void * alpha, const void * a, library_data_t a_type, int lda,
const void *alpha, const void *a, library_data_t a_type, long long int stride_a, const void * b, library_data_t b_type, int ldb,
int lda, long long int stride_a, const void *b, long long int stride_b, const void * beta, void * c, library_data_t c_type, int ldc,
library_data_t b_type, int ldb, long long int stride_b, long long int stride_c, int batch_size, library_data_t scaling_type) {
const void *beta, void *c, library_data_t c_type,
int ldc, long long int stride_c, int batch_size,
library_data_t scaling_type)
{
if (scaling_type == library_data_t::real_float && if (scaling_type == library_data_t::real_float &&
c_type == library_data_t::complex_float) c_type == library_data_t::complex_float)
{ {
@ -2611,20 +2602,18 @@ namespace dpct
library_data_t::real_bfloat16, library_data_t::real_bfloat16, library_data_t::real_bfloat16, library_data_t::real_bfloat16,
library_data_t::real_bfloat16, library_data_t::real_float): library_data_t::real_bfloat16, library_data_t::real_float):
{ {
detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, detail::gemm_batch_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, oneapi::math::bfloat16, float>(
oneapi::mkl::bfloat16, float>( q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c,
q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, batch_size);
beta, c, ldc, stride_c, batch_size);
break; break;
} }
case detail::get_type_combination_id( case detail::get_type_combination_id(
library_data_t::real_bfloat16, library_data_t::real_bfloat16, library_data_t::real_bfloat16, library_data_t::real_bfloat16,
library_data_t::real_float, library_data_t::real_float): library_data_t::real_float, library_data_t::real_float):
{ {
detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float, detail::gemm_batch_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, float, float>(
float>(q, a_trans, b_trans, m, n, k, alpha, a, lda, q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c,
stride_a, b, ldb, stride_b, beta, c, ldc, batch_size);
stride_c, batch_size);
break; break;
} }
#endif #endif

View file

@ -509,497 +509,409 @@ static void pad_f32_sycl(const float *x, float *dst, const int ne00,
}); });
} }
inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
ggml_tensor *dst, const float *src0_dd,
const float *src1_dd, float *dst_dd,
const queue_ptr &main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
silu_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
GGML_UNUSED(src1); silu_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream);
GGML_UNUSED(dst);
GGML_UNUSED(src1_dd);
GGML_UNUSED(ctx);
} }
inline void ggml_sycl_op_gelu(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, inline void ggml_sycl_op_gelu(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
ggml_tensor *dst, const float *src0_dd,
const float *src1_dd, float *dst_dd,
const queue_ptr &main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
gelu_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream); gelu_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream);
GGML_UNUSED(src1);
GGML_UNUSED(dst);
GGML_UNUSED(src1_dd);
GGML_UNUSED(ctx);
}
inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
const ggml_tensor *src1, ggml_tensor *dst,
const float *src0_dd, const float *src1_dd,
float *dst_dd,
const queue_ptr &main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
gelu_quick_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
GGML_UNUSED(src1);
GGML_UNUSED(dst);
GGML_UNUSED(src1_dd);
GGML_UNUSED(ctx);
} }
inline void ggml_sycl_op_tanh(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
ggml_tensor *dst, const float *src0_dd,
const float *src1_dd, float *dst_dd,
const queue_ptr &main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32);
tanh_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream); dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
GGML_UNUSED(src1); gelu_quick_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream);
GGML_UNUSED(dst);
GGML_UNUSED(src1_dd);
GGML_UNUSED(ctx);
} }
inline void ggml_sycl_op_relu(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, inline void ggml_sycl_op_tanh(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
ggml_tensor *dst, const float *src0_dd,
const float *src1_dd, float *dst_dd,
const queue_ptr &main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32);
dpct::queue_ptr main_stream = ctx.stream();
relu_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream); SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
GGML_UNUSED(src1); float * dst_dd = static_cast<float *>(dst->data);
GGML_UNUSED(dst); tanh_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream);
GGML_UNUSED(src1_dd);
GGML_UNUSED(ctx);
} }
inline void ggml_sycl_op_hardsigmoid(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, inline void ggml_sycl_op_relu(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
const ggml_tensor *src1, ggml_tensor *dst,
const float *src0_dd, const float *src1_dd,
float *dst_dd,
const queue_ptr &main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
hardsigmoid_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream); relu_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream);
GGML_UNUSED(src1);
GGML_UNUSED(dst);
GGML_UNUSED(src1_dd);
GGML_UNUSED(ctx);
} }
inline void ggml_sycl_op_hardswish(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, inline void ggml_sycl_op_hardsigmoid(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
const ggml_tensor *src1, ggml_tensor *dst,
const float *src0_dd, const float *src1_dd,
float *dst_dd, const queue_ptr &main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
hardswish_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream); hardsigmoid_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream);
GGML_UNUSED(src1);
GGML_UNUSED(dst);
GGML_UNUSED(src1_dd);
GGML_UNUSED(ctx);
} }
inline void ggml_sycl_op_exp(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, inline void ggml_sycl_op_hardswish(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
const ggml_tensor *src1, ggml_tensor *dst,
const float *src0_dd, const float *src1_dd,
float *dst_dd, const queue_ptr &main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
exp_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream); hardswish_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream);
GGML_UNUSED(src1);
GGML_UNUSED(dst);
GGML_UNUSED(src1_dd);
GGML_UNUSED(ctx);
} }
inline void ggml_sycl_op_log(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, inline void ggml_sycl_op_exp(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
const ggml_tensor *src1, ggml_tensor *dst,
const float *src0_dd, const float *src1_dd,
float *dst_dd, const queue_ptr &main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
log_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream); exp_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream);
GGML_UNUSED(src1);
GGML_UNUSED(dst);
GGML_UNUSED(src1_dd);
GGML_UNUSED(ctx);
} }
inline void ggml_sycl_op_sigmoid(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, inline void ggml_sycl_op_log(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
const ggml_tensor *src1, ggml_tensor *dst,
const float *src0_dd, const float *src1_dd,
float *dst_dd, const queue_ptr &main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
sigmoid_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream); log_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream);
GGML_UNUSED(src1);
GGML_UNUSED(dst);
GGML_UNUSED(src1_dd);
GGML_UNUSED(ctx);
} }
inline void ggml_sycl_op_sqrt(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, inline void ggml_sycl_op_sigmoid(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
const ggml_tensor *src1, ggml_tensor *dst,
const float *src0_dd, const float *src1_dd,
float *dst_dd, const queue_ptr &main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
sqrt_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
GGML_UNUSED(src1); sigmoid_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream);
GGML_UNUSED(dst);
GGML_UNUSED(src1_dd);
GGML_UNUSED(ctx);
} }
inline void ggml_sycl_op_sin(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, inline void ggml_sycl_op_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
const ggml_tensor *src1, ggml_tensor *dst,
const float *src0_dd, const float *src1_dd,
float *dst_dd, const queue_ptr &main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
sin_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
GGML_UNUSED(src1); sqrt_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream);
GGML_UNUSED(dst);
GGML_UNUSED(src1_dd);
GGML_UNUSED(ctx);
} }
inline void ggml_sycl_op_cos(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, inline void ggml_sycl_op_sin(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
const ggml_tensor *src1, ggml_tensor *dst,
const float *src0_dd, const float *src1_dd,
float *dst_dd, const queue_ptr &main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
cos_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream); sin_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream);
GGML_UNUSED(src1);
GGML_UNUSED(dst);
GGML_UNUSED(src1_dd);
GGML_UNUSED(ctx);
} }
inline void ggml_sycl_op_step(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, inline void ggml_sycl_op_cos(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
const ggml_tensor *src1, ggml_tensor *dst,
const float *src0_dd, const float *src1_dd,
float *dst_dd, const queue_ptr &main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
step_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream); cos_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream);
GGML_UNUSED(src1);
GGML_UNUSED(dst);
GGML_UNUSED(src1_dd);
GGML_UNUSED(ctx);
} }
inline void ggml_sycl_op_neg(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, inline void ggml_sycl_op_step(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
const ggml_tensor *src1, ggml_tensor *dst,
const float *src0_dd, const float *src1_dd,
float *dst_dd, const queue_ptr &main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
neg_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream); step_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream);
GGML_UNUSED(src1);
GGML_UNUSED(dst);
GGML_UNUSED(src1_dd);
GGML_UNUSED(ctx);
} }
inline void ggml_sycl_op_leaky_relu(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
const ggml_tensor *src1, ggml_tensor *dst,
const float *src0_dd, const float *src1_dd,
float *dst_dd,
const queue_ptr &main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F32); inline void ggml_sycl_op_neg(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
neg_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream);
}
inline void ggml_sycl_op_leaky_relu(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
float negative_slope; float negative_slope;
memcpy(&negative_slope, dst->op_params, sizeof(float)); memcpy(&negative_slope, dst->op_params, sizeof(float));
leaky_relu_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), negative_slope, main_stream); leaky_relu_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), negative_slope, main_stream);
GGML_UNUSED(src1);
GGML_UNUSED(dst);
GGML_UNUSED(src1_dd);
GGML_UNUSED(ctx);
} }
inline void ggml_sycl_op_sqr(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, inline void ggml_sycl_op_sqr(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
ggml_tensor *dst, const float *src0_dd,
const float *src1_dd, float *dst_dd,
const queue_ptr &main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
sqr_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream); sqr_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream);
GGML_UNUSED(src1);
GGML_UNUSED(dst);
GGML_UNUSED(src1_dd);
GGML_UNUSED(ctx);
} }
inline void ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, inline void ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
const ggml_tensor *src1, ggml_tensor *dst,
const float *src0_dd, const float *src1_dd,
float *dst_dd,
const queue_ptr &main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32); GGML_ASSERT(dst->type == GGML_TYPE_F32);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
const float sf0 = (float)dst->ne[0]/src0->ne[0]; const float sf0 = (float)dst->ne[0]/dst->src[0]->ne[0];
const float sf1 = (float)dst->ne[1]/src0->ne[1]; const float sf1 = (float)dst->ne[1]/dst->src[0]->ne[1];
const float sf2 = (float)dst->ne[2]/src0->ne[2]; const float sf2 = (float)dst->ne[2]/dst->src[0]->ne[2];
const float sf3 = (float)dst->ne[3]/src0->ne[3]; const float sf3 = (float)dst->ne[3]/dst->src[0]->ne[3];
upscale_f32_sycl(src0_dd, dst_dd, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], upscale_f32_sycl(src0_dd, dst_dd, dst->src[0]->nb[0], dst->src[0]->nb[1], dst->src[0]->nb[2], dst->src[0]->nb[3],
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3, dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3,
main_stream); main_stream);
GGML_UNUSED(src1);
GGML_UNUSED(dst);
GGML_UNUSED(src1_dd);
GGML_UNUSED(ctx);
} }
inline void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, inline void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
ggml_tensor *dst, const float *src0_dd,
const float *src1_dd, float *dst_dd,
const queue_ptr &main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32); GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors GGML_ASSERT(dst->src[0]->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
pad_f32_sycl(src0_dd, dst_dd, pad_f32_sycl(src0_dd, dst_dd,
src0->ne[0], src0->ne[1], src0->ne[2], dst->src[0]->ne[0], dst->src[0]->ne[1], dst->src[0]->ne[2],
dst->ne[0], dst->ne[1], dst->ne[2], main_stream); dst->ne[0], dst->ne[1], dst->ne[2], main_stream);
GGML_UNUSED(src1);
GGML_UNUSED(dst);
GGML_UNUSED(src1_dd);
GGML_UNUSED(ctx);
} }
inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
ggml_tensor *dst, const float *src0_dd,
const float *src1_dd, float *dst_dd,
const queue_ptr &main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F32); GGML_ASSERT(dst->src[1]->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_ASSERT(dst->ne[3] == 1); // just 3D tensors supported GGML_ASSERT(dst->ne[3] == 1); // just 3D tensors supported
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
const float * src1_dd = static_cast<const float*>(dst->src[1]->data);
float * dst_dd = static_cast<float *>(dst->data);
int nb1 = dst->op_params[0] / 4; // 4 bytes of float32 int nb1 = dst->op_params[0] / 4; // 4 bytes of float32
int nb2 = dst->op_params[1] / 4; // 4 bytes of float32 int nb2 = dst->op_params[1] / 4; // 4 bytes of float32
// int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused // int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused
int offset = dst->op_params[3] / 4; // offset in bytes int offset = dst->op_params[3] / 4; // offset in bytes
acc_f32_sycl(src0_dd, src1_dd, dst_dd, ggml_nelements(dst), src1->ne[0], src1->ne[1], src1->ne[2], nb1, nb2, offset, main_stream); acc_f32_sycl(src0_dd, src1_dd, dst_dd, ggml_nelements(dst), dst->src[1]->ne[0], dst->src[1]->ne[1], dst->src[1]->ne[2], nb1, nb2, offset, main_stream);
GGML_UNUSED(dst);
GGML_UNUSED(ctx);
} }
inline void ggml_sycl_op_add(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, inline void ggml_sycl_op_add(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
ggml_tensor *dst, const float *src0_dd,
const float *src1_dd, float *dst_dd,
const queue_ptr &main_stream) {
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_add>>(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream); ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_add>>(ctx, dst->src[0], dst->src[1], dst);
} }
inline void ggml_sycl_op_sub(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, inline void ggml_sycl_op_sub(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
ggml_tensor *dst, const float *src0_dd,
const float *src1_dd, float *dst_dd,
const queue_ptr &main_stream) {
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_sub>>(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream); ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_sub>>(ctx, dst->src[0], dst->src[1], dst);
} }
inline void ggml_sycl_op_mul(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, inline void ggml_sycl_op_mul(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
ggml_tensor *dst, const float *src0_dd,
const float *src1_dd, float *dst_dd,
const queue_ptr &main_stream) {
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_mul>>(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream); ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_mul>>(ctx, dst->src[0], dst->src[1], dst);
} }
inline void ggml_sycl_op_div(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, inline void ggml_sycl_op_div(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
ggml_tensor *dst, const float *src0_dd,
const float *src1_dd, float *dst_dd,
const queue_ptr &main_stream) {
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_div>>(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream); ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_div>>(ctx, dst->src[0], dst->src[1], dst);
} }
void ggml_sycl_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { void ggml_sycl_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__); GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_sqrt); ggml_sycl_op_sqrt(ctx, dst);
GGML_SYCL_DEBUG("call %s done\n", __func__); GGML_SYCL_DEBUG("call %s done\n", __func__);
} }
void ggml_sycl_sin(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { void ggml_sycl_sin(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__); GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_sin); ggml_sycl_op_sin(ctx, dst);
GGML_SYCL_DEBUG("call %s done\n", __func__); GGML_SYCL_DEBUG("call %s done\n", __func__);
} }
void ggml_sycl_cos(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { void ggml_sycl_cos(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__); GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_cos); ggml_sycl_op_cos(ctx, dst);
GGML_SYCL_DEBUG("call %s done\n", __func__); GGML_SYCL_DEBUG("call %s done\n", __func__);
} }
void ggml_sycl_acc(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { void ggml_sycl_acc(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__); GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_acc); ggml_sycl_op_acc(ctx, dst);
GGML_SYCL_DEBUG("call %s done\n", __func__); GGML_SYCL_DEBUG("call %s done\n", __func__);
} }
void ggml_sycl_gelu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { void ggml_sycl_gelu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__); GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_gelu); ggml_sycl_op_gelu(ctx, dst);
GGML_SYCL_DEBUG("call %s done\n", __func__); GGML_SYCL_DEBUG("call %s done\n", __func__);
} }
void ggml_sycl_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { void ggml_sycl_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__); GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_silu); ggml_sycl_op_silu(ctx, dst);
GGML_SYCL_DEBUG("call %s done\n", __func__); GGML_SYCL_DEBUG("call %s done\n", __func__);
} }
void ggml_sycl_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { void ggml_sycl_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__); GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_gelu_quick); ggml_sycl_op_gelu_quick(ctx, dst);
GGML_SYCL_DEBUG("call %s done\n", __func__); GGML_SYCL_DEBUG("call %s done\n", __func__);
} }
void ggml_sycl_tanh(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { void ggml_sycl_tanh(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__); GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_tanh); ggml_sycl_op_tanh(ctx, dst);
GGML_SYCL_DEBUG("call %s done\n", __func__); GGML_SYCL_DEBUG("call %s done\n", __func__);
} }
void ggml_sycl_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { void ggml_sycl_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__); GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_relu); ggml_sycl_op_relu(ctx, dst);
GGML_SYCL_DEBUG("call %s done\n", __func__); GGML_SYCL_DEBUG("call %s done\n", __func__);
} }
void ggml_sycl_sigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { void ggml_sycl_sigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__); GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_sigmoid); ggml_sycl_op_sigmoid(ctx, dst);
GGML_SYCL_DEBUG("call %s done\n", __func__); GGML_SYCL_DEBUG("call %s done\n", __func__);
} }
void ggml_sycl_hardsigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { void ggml_sycl_hardsigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__); GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_hardsigmoid); ggml_sycl_op_hardsigmoid(ctx, dst);
GGML_SYCL_DEBUG("call %s done\n", __func__); GGML_SYCL_DEBUG("call %s done\n", __func__);
} }
void ggml_sycl_hardswish(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { void ggml_sycl_hardswish(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__); GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_hardswish); ggml_sycl_op_hardswish(ctx, dst);
GGML_SYCL_DEBUG("call %s done\n", __func__); GGML_SYCL_DEBUG("call %s done\n", __func__);
} }
void ggml_sycl_exp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { void ggml_sycl_exp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__); GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_exp); ggml_sycl_op_exp(ctx, dst);
GGML_SYCL_DEBUG("call %s done\n", __func__); GGML_SYCL_DEBUG("call %s done\n", __func__);
} }
void ggml_sycl_log(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { void ggml_sycl_log(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__); GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_log); ggml_sycl_op_log(ctx, dst);
GGML_SYCL_DEBUG("call %s done\n", __func__); GGML_SYCL_DEBUG("call %s done\n", __func__);
} }
void ggml_sycl_neg(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { void ggml_sycl_neg(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__); GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_neg); ggml_sycl_op_neg(ctx, dst);
GGML_SYCL_DEBUG("call %s done\n", __func__); GGML_SYCL_DEBUG("call %s done\n", __func__);
} }
void ggml_sycl_step(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { void ggml_sycl_step(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__); GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_step); ggml_sycl_op_step(ctx, dst);
GGML_SYCL_DEBUG("call %s done\n", __func__); GGML_SYCL_DEBUG("call %s done\n", __func__);
} }
void ggml_sycl_leaky_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { void ggml_sycl_leaky_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__); GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_leaky_relu); ggml_sycl_op_leaky_relu(ctx, dst);
GGML_SYCL_DEBUG("call %s done\n", __func__); GGML_SYCL_DEBUG("call %s done\n", __func__);
} }
void ggml_sycl_sqr(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { void ggml_sycl_sqr(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__); GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_sqr); ggml_sycl_op_sqr(ctx, dst);
GGML_SYCL_DEBUG("call %s done\n", __func__); GGML_SYCL_DEBUG("call %s done\n", __func__);
} }
void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__); GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_upscale); ggml_sycl_op_upscale(ctx, dst);
GGML_SYCL_DEBUG("call %s done\n", __func__); GGML_SYCL_DEBUG("call %s done\n", __func__);
} }
void ggml_sycl_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { void ggml_sycl_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__); GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_pad); ggml_sycl_op_pad(ctx, dst);
GGML_SYCL_DEBUG("call %s done\n", __func__); GGML_SYCL_DEBUG("call %s done\n", __func__);
} }
@ -1007,24 +919,24 @@ void ggml_sycl_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
void ggml_sycl_add(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { void ggml_sycl_add(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__); GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_add); ggml_sycl_op_add(ctx, dst);
GGML_SYCL_DEBUG("call %s done\n", __func__); GGML_SYCL_DEBUG("call %s done\n", __func__);
} }
void ggml_sycl_sub(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { void ggml_sycl_sub(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__); GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_sub); ggml_sycl_op_sub(ctx, dst);
GGML_SYCL_DEBUG("call %s done\n", __func__); GGML_SYCL_DEBUG("call %s done\n", __func__);
} }
void ggml_sycl_mul(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { void ggml_sycl_mul(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__); GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_mul); ggml_sycl_op_mul(ctx, dst);
GGML_SYCL_DEBUG("call %s done\n", __func__); GGML_SYCL_DEBUG("call %s done\n", __func__);
} }
void ggml_sycl_div(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { void ggml_sycl_div(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__); GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_div); ggml_sycl_op_div(ctx, dst);
GGML_SYCL_DEBUG("call %s done\n", __func__); GGML_SYCL_DEBUG("call %s done\n", __func__);
} }

View file

@ -257,50 +257,54 @@ static void get_rows_sycl_float(ggml_backend_sycl_context & ctx, const ggml_tens
GGML_UNUSED(ctx); GGML_UNUSED(ctx);
} }
void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
const ggml_tensor *src1, ggml_tensor *dst,
const float *src0_d, const float *src1_d,
float *dst_d, const queue_ptr &stream) {
GGML_ASSERT(src1->type == GGML_TYPE_I32); GGML_ASSERT(dst->src[1]->type == GGML_TYPE_I32);
GGML_ASSERT(dst->type == GGML_TYPE_F32); GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type)); GGML_ASSERT(dst->src[0]->nb[0] == ggml_type_size(dst->src[0]->type));
GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type)); GGML_ASSERT(dst->src[1]->nb[0] == ggml_type_size(dst->src[1]->type));
GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type)); GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type));
const int32_t * src1_i32 = (const int32_t *) src1_d; const int32_t * src1_i32 = (const int32_t *) dst->src[1]->data;
/* TODO: Refactor and remove duplicates */
switch (src0->type) { switch (dst->src[0]->type) {
case GGML_TYPE_F16: case GGML_TYPE_F16:
get_rows_sycl_float(ctx, src0, src1, dst, (const sycl::half *)src0_d, get_rows_sycl_float(ctx, dst->src[0], dst->src[1], dst, (const sycl::half *)dst->src[0]->data,
src1_i32, dst_d, stream); src1_i32, (float *)dst->data, ctx.stream());
break; break;
case GGML_TYPE_F32: case GGML_TYPE_F32:
get_rows_sycl_float(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream); get_rows_sycl_float(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data,
src1_i32, (float *)dst->data, ctx.stream());
break; break;
case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_0:
if (ctx.opt_feature.reorder && dst->op == GGML_OP_MUL_MAT) { if (ctx.opt_feature.reorder && dst->op == GGML_OP_MUL_MAT) {
get_rows_sycl_reorder<QK4_0, QR4_0, dequantize_q4_0_reorder>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream); get_rows_sycl_reorder<QK4_0, QR4_0, dequantize_q4_0_reorder>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data,
src1_i32, (float *)dst->data, ctx.stream());
} else { } else {
get_rows_sycl<QK4_0, QR4_0, dequantize_q4_0>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream); get_rows_sycl<QK4_0, QR4_0, dequantize_q4_0>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data,
src1_i32, (float *)dst->data, ctx.stream());
} }
break; break;
case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_1:
get_rows_sycl<QK4_1, QR4_1, dequantize_q4_1>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream); get_rows_sycl<QK4_1, QR4_1, dequantize_q4_1>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data,
src1_i32, (float *)dst->data, ctx.stream());
break; break;
case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_0:
get_rows_sycl<QK5_0, QR5_0, dequantize_q5_0>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream); get_rows_sycl<QK5_0, QR5_0, dequantize_q5_0>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data,
src1_i32, (float *)dst->data, ctx.stream());
break; break;
case GGML_TYPE_Q5_1: case GGML_TYPE_Q5_1:
get_rows_sycl<QK5_1, QR5_1, dequantize_q5_1>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream); get_rows_sycl<QK5_1, QR5_1, dequantize_q5_1>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data,
src1_i32, (float *)dst->data, ctx.stream());
break; break;
case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_0:
get_rows_sycl<QK8_0, QR8_0, dequantize_q8_0>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream); get_rows_sycl<QK8_0, QR8_0, dequantize_q8_0>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data,
src1_i32, (float *)dst->data, ctx.stream());
break; break;
default: default:
// TODO: k-quants // TODO: k-quants
GGML_LOG_ERROR("%s: unsupported type: %s\n", __func__, ggml_type_name(src0->type)); GGML_LOG_ERROR("%s: unsupported type: %s\n", __func__, ggml_type_name(dst->src[0]->type));
GGML_ABORT("fatal error"); GGML_ABORT("fatal error");
} }
} }

View file

@ -15,9 +15,6 @@
#include "common.hpp" #include "common.hpp"
void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor *dst);
const ggml_tensor *src1, ggml_tensor *dst,
const float *src0_d, const float *src1_d,
float *dst_d, const queue_ptr &stream);
#endif // GGML_SYCL_GETROWS_HPP #endif // GGML_SYCL_GETROWS_HPP

View file

@ -1988,16 +1988,8 @@ catch (sycl::exception const &exc) {
std::exit(1); std::exit(1);
} }
static void ggml_sycl_op_repeat(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, static void ggml_sycl_op_repeat(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
const ggml_tensor *src1, ggml_tensor *dst, ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_repeat>>(ctx, dst, dst->src[0], dst);
const float *src0_d, const float *src1_d,
float *dst_d,
const queue_ptr &main_stream) {
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_repeat>>(ctx, dst, src0, dst, nullptr, src0_d, dst_d, main_stream);
GGML_UNUSED(src1);
GGML_UNUSED(src1_d);
} }
@ -2067,8 +2059,8 @@ inline void ggml_sycl_op_mul_mat_sycl(
const sycl::half alpha_f16 = 1.0f; const sycl::half alpha_f16 = 1.0f;
const sycl::half beta_f16 = 0.0f; const sycl::half beta_f16 = 0.0f;
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm( SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm(
*stream, oneapi::mkl::transpose::trans, *stream, oneapi::math::transpose::trans,
oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10, oneapi::math::transpose::nontrans, row_diff, src1_ncols, ne10,
&alpha_f16, src0_ptr, dpct::library_data_t::real_half, ne00, &alpha_f16, src0_ptr, dpct::library_data_t::real_half, ne00,
src1_ptr, dpct::library_data_t::real_half, ne10, &beta_f16, src1_ptr, dpct::library_data_t::real_half, ne10, &beta_f16,
dst_f16.get(), dpct::library_data_t::real_half, ldc, dst_f16.get(), dpct::library_data_t::real_half, ldc,
@ -2105,17 +2097,10 @@ inline void ggml_sycl_op_mul_mat_sycl(
#if !GGML_SYCL_DNNL #if !GGML_SYCL_DNNL
const float alpha = 1.0f; const float alpha = 1.0f;
const float beta = 0.0f; const float beta = 0.0f;
# ifdef GGML_SYCL_NVIDIA SYCL_CHECK(CHECK_TRY_ERROR(oneapi::math::blas::column_major::gemm(
SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm( get_onemath_backend(*stream), oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, row_diff,
oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ *stream }, oneapi::mkl::transpose::trans, src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10,
oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i, dpct::get_value(&beta, *stream), dst_dd_i, ldc)));
ne00, src1_ddf1_i, ne10, dpct::get_value(&beta, *stream), dst_dd_i, ldc)));
# else
SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm(
*stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10,
dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10, dpct::get_value(&beta, *stream),
dst_dd_i, ldc)));
# endif
#else #else
DnnlGemmWrapper::row_gemm(ctx, false, true, src1_ncols, row_diff, ne10, src1_ddf1_i, DnnlGemmWrapper::row_gemm(ctx, false, true, src1_ncols, row_diff, ne10, src1_ddf1_i,
DnnlGemmWrapper::to_dt<float>(), src0_ddf_i, DnnlGemmWrapper::to_dt<float>(), DnnlGemmWrapper::to_dt<float>(), src0_ddf_i, DnnlGemmWrapper::to_dt<float>(),
@ -2132,13 +2117,14 @@ catch (sycl::exception const &exc) {
std::exit(1); std::exit(1);
} }
static void ggml_sycl_op_pool2d(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, static void ggml_sycl_op_pool2d(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
const ggml_tensor *src1, ggml_tensor *dst,
const float *src0_dd, const float *src1_dd,
float *dst_dd, const queue_ptr &main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
const int32_t * opts = (const int32_t *)dst->op_params; const int32_t * opts = (const int32_t *)dst->op_params;
enum ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]); enum ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
@ -2149,8 +2135,8 @@ static void ggml_sycl_op_pool2d(ggml_backend_sycl_context & ctx, const ggml_tens
const int p0 = opts[5]; const int p0 = opts[5];
const int p1 = opts[6]; const int p1 = opts[6];
const int64_t IH = src0->ne[1]; const int64_t IH = dst->src[0]->ne[1];
const int64_t IW = src0->ne[0]; const int64_t IW = dst->src[0]->ne[0];
const int64_t N = dst->ne[3]; const int64_t N = dst->ne[3];
const int64_t OC = dst->ne[2]; const int64_t OC = dst->ne[2];
@ -2169,163 +2155,125 @@ static void ggml_sycl_op_pool2d(ggml_backend_sycl_context & ctx, const ggml_tens
parallel_elements, src0_dd, dst_dd, op, parallel_elements, src0_dd, dst_dd, op,
item_ct1); item_ct1);
}); });
GGML_UNUSED(src1);
GGML_UNUSED(src1_dd);
GGML_UNUSED(ctx);
} }
inline void ggml_sycl_op_sum(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, inline void ggml_sycl_op_sum(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
const ggml_tensor *src1, ggml_tensor *dst, GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
const float *src0_dd, const float *src1_dd,
float *dst_dd,
const queue_ptr &main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
const int64_t ne = ggml_nelements(src0); const int64_t ne = ggml_nelements(dst->src[0]);
sum_rows_f32_sycl(src0_dd, dst_dd, ne, 1, main_stream); sum_rows_f32_sycl(src0_dd, dst_dd, ne, 1, main_stream);
GGML_UNUSED(src1);
GGML_UNUSED(dst);
GGML_UNUSED(src1_dd);
GGML_UNUSED(ctx);
} }
inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
const ggml_tensor *src1, ggml_tensor *dst,
const float *src0_dd, const float *src1_dd,
float *dst_dd,
const queue_ptr &main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
const int64_t ncols = src0->ne[0]; const int64_t ncols = dst->src[0]->ne[0];
const int64_t nrows = ggml_nrows(src0); const int64_t nrows = ggml_nrows(dst->src[0]);
sum_rows_f32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream); sum_rows_f32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream);
GGML_UNUSED(src1);
GGML_UNUSED(dst);
GGML_UNUSED(src1_dd);
GGML_UNUSED(ctx);
} }
inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
const ggml_tensor *src1, ggml_tensor *dst, GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
const float *src0_dd, const float *src1_dd, GGML_ASSERT(dst->type == GGML_TYPE_I32);
float *dst_dd, dpct::queue_ptr main_stream = ctx.stream();
const queue_ptr &main_stream) { SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
int32_t * dst_dd = static_cast<int32_t *>(dst->data);
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_I32);
const int64_t ncols = src0->ne[0]; const int64_t ncols = dst->src[0]->ne[0];
const int64_t nrows = ggml_nrows(src0); const int64_t nrows = ggml_nrows(dst->src[0]);
enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0]; enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
argsort_f32_i32_sycl(src0_dd, (int *)dst_dd, ncols, nrows, order, main_stream); argsort_f32_i32_sycl(src0_dd, (int *) dst_dd, ncols, nrows, order, main_stream);
GGML_UNUSED(src1);
GGML_UNUSED(dst);
GGML_UNUSED(src1_dd);
GGML_UNUSED(ctx);
} }
inline void ggml_sycl_op_argmax(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, inline void ggml_sycl_op_argmax(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
const ggml_tensor *src1, ggml_tensor *dst,
const float *src0_dd, const float *src1_dd,
float *dst_dd,
const queue_ptr &main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_I32); GGML_ASSERT( dst->type == GGML_TYPE_I32);
const int64_t ncols = src0->ne[0]; dpct::queue_ptr main_stream = ctx.stream();
const int64_t nrows = ggml_nrows(src0); SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
int32_t * dst_dd = static_cast<int32_t *>(dst->data);
argmax_f32_i32_sycl(src0_dd, (int *)dst_dd, ncols, nrows, main_stream); const int64_t ncols = dst->src[0]->ne[0];
const int64_t nrows = ggml_nrows(dst->src[0]);
GGML_UNUSED(src1); argmax_f32_i32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream);
GGML_UNUSED(dst);
GGML_UNUSED(src1_dd);
GGML_UNUSED(ctx);
} }
inline void ggml_sycl_op_diag_mask_inf(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, inline void ggml_sycl_op_diag_mask_inf(ggml_backend_sycl_context & ctx,ggml_tensor *dst) {
const ggml_tensor *src1,
ggml_tensor *dst, const float *src0_dd,
const float *src1_dd, float *dst_dd,
const queue_ptr &main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
const int64_t ne00 = src0->ne[0]; const int64_t ne00 = dst->src[0]->ne[0];
const int64_t ne01 = src0->ne[1]; const int64_t ne01 = dst->src[0]->ne[1];
const int nrows0 = ggml_nrows(src0); const int nrows0 = ggml_nrows(dst->src[0]);
const int n_past = ((int32_t *) dst->op_params)[0]; const int n_past = ((int32_t *) dst->op_params)[0];
diag_mask_inf_f32_sycl(src0_dd, dst_dd, ne00, nrows0, ne01, n_past, main_stream); diag_mask_inf_f32_sycl(src0_dd, dst_dd, ne00, nrows0, ne01, n_past, main_stream);
GGML_UNUSED(src1);
GGML_UNUSED(dst);
GGML_UNUSED(src1_dd);
GGML_UNUSED(ctx);
} }
inline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, inline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
ggml_tensor *dst, const float *src0_dd,
const float *src1_dd, float *dst_dd,
const queue_ptr &main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
float scale; float scale;
memcpy(&scale, dst->op_params, sizeof(float)); memcpy(&scale, dst->op_params, sizeof(float));
scale_f32_sycl(src0_dd, dst_dd, scale, ggml_nelements(src0), main_stream); scale_f32_sycl(src0_dd, dst_dd, scale, ggml_nelements(dst->src[0]), main_stream);
/* /*
DPCT1010:87: SYCL uses exceptions to report errors and does not use the DPCT1010:87: SYCL uses exceptions to report errors and does not use the
error codes. The call was replaced with 0. You need to rewrite this code. error codes. The call was replaced with 0. You need to rewrite this code.
*/ */
SYCL_CHECK(0); SYCL_CHECK(0);
GGML_UNUSED(src1);
GGML_UNUSED(dst);
GGML_UNUSED(src1_dd);
GGML_UNUSED(ctx);
} }
inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
ggml_tensor *dst, const float *src0_dd,
const float *src1_dd, float *dst_dd,
const queue_ptr &main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32);
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
float min; float min;
float max; float max;
memcpy(&min, dst->op_params, sizeof(float)); memcpy(&min, dst->op_params, sizeof(float));
memcpy(&max, (float *) dst->op_params + 1, sizeof(float)); memcpy(&max, (float *) dst->op_params + 1, sizeof(float));
clamp_f32_sycl(src0_dd, dst_dd, min, max, ggml_nelements(src0), main_stream); clamp_f32_sycl(src0_dd, dst_dd, min, max, ggml_nelements(dst->src[0]), ctx.stream());
/* /*
DPCT1010:88: SYCL uses exceptions to report errors and does not use the DPCT1010:88: SYCL uses exceptions to report errors and does not use the
error codes. The call was replaced with 0. You need to rewrite this code. error codes. The call was replaced with 0. You need to rewrite this code.
*/ */
SYCL_CHECK(0); SYCL_CHECK(0);
GGML_UNUSED(src1);
GGML_UNUSED(dst);
GGML_UNUSED(src1_dd);
GGML_UNUSED(ctx);
} }
static void ggml_sycl_set_peer_access(const int n_tokens, int main_device) { static void ggml_sycl_set_peer_access(const int n_tokens, int main_device) {
@ -2695,37 +2643,37 @@ catch (sycl::exception const &exc) {
static void ggml_sycl_repeat(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { static void ggml_sycl_repeat(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__); GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_repeat); ggml_sycl_op_repeat(ctx, dst);
GGML_SYCL_DEBUG("call %s done\n", __func__); GGML_SYCL_DEBUG("call %s done\n", __func__);
} }
static void ggml_sycl_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { static void ggml_sycl_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__); GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_get_rows); ggml_sycl_op_get_rows(ctx, dst);
GGML_SYCL_DEBUG("call %s done\n", __func__); GGML_SYCL_DEBUG("call %s done\n", __func__);
} }
static void ggml_sycl_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { static void ggml_sycl_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__); GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_norm); ggml_sycl_op_norm(ctx, dst);
GGML_SYCL_DEBUG("call %s done\n", __func__); GGML_SYCL_DEBUG("call %s done\n", __func__);
} }
static void ggml_sycl_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { static void ggml_sycl_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__); GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_rms_norm); ggml_sycl_op_rms_norm(ctx, dst);
GGML_SYCL_DEBUG("call %s done\n", __func__); GGML_SYCL_DEBUG("call %s done\n", __func__);
} }
static void ggml_sycl_l2_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { static void ggml_sycl_l2_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__); GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_l2_norm); ggml_sycl_op_l2_norm(ctx, dst);
GGML_SYCL_DEBUG("call %s done\n", __func__); GGML_SYCL_DEBUG("call %s done\n", __func__);
} }
static void ggml_sycl_group_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { static void ggml_sycl_group_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__); GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_group_norm); ggml_sycl_op_group_norm(ctx, dst);
GGML_SYCL_DEBUG("call %s done\n", __func__); GGML_SYCL_DEBUG("call %s done\n", __func__);
} }
@ -2881,14 +2829,10 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) { if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
// there is no broadcast and src0, src1 are contiguous across dims 2, 3 // there is no broadcast and src0, src1 are contiguous across dims 2, 3
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch( SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
*main_stream, oneapi::mkl::transpose::trans, *main_stream, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha, (const char *) src0_as_f16, dpct::library_data_t::real_half, nb01 / nb00, nb02 / nb00,
(const char *)src0_as_f16, dpct::library_data_t::real_half, (const char *) src1_f16, dpct::library_data_t::real_half, nb11 / nb10, nb12 / nb10, beta, (char *) dst_t,
nb01 / nb00, nb02 / nb00, cu_data_type, ne01, nb2 / nb0, ne12 * ne13, cu_compute_type)));
(const char *)src1_f16, dpct::library_data_t::real_half,
nb11 / nb10, nb12 / nb10, beta,
(char *)dst_t, cu_data_type, ne01, nb2 / nb0,
ne12 * ne13, cu_compute_type)));
} else { } else {
const int ne23 = ne12*ne13; const int ne23 = ne12*ne13;
@ -2923,7 +2867,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
}); });
} }
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch( SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
*main_stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha, *main_stream, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
(const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00, (const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00,
(const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, nb11 / nb10, beta, (const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, nb11 / nb10, beta,
(void **) (ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23, cu_compute_type, matrix_info.get()))); (void **) (ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23, cu_compute_type, matrix_info.get())));
@ -3269,48 +3213,48 @@ catch (sycl::exception const &exc) {
} }
static void ggml_sycl_scale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { static void ggml_sycl_scale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_scale); ggml_sycl_op_scale(ctx, dst);
} }
static void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { static void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_clamp); ggml_sycl_op_clamp(ctx, dst);
} }
static void ggml_sycl_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { static void ggml_sycl_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_diag_mask_inf); ggml_sycl_op_diag_mask_inf(ctx, dst);
} }
static void ggml_sycl_rope(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { static void ggml_sycl_rope(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_ASSERT(ggml_is_contiguous(dst->src[0])); // TODO: this restriction is temporary until non-cont support is implemented GGML_ASSERT(ggml_is_contiguous(dst->src[0])); // TODO: this restriction is temporary until non-cont support is implemented
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_rope); ggml_sycl_op_rope(ctx, dst);
} }
static void ggml_sycl_pool2d(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { static void ggml_sycl_pool2d(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_pool2d); ggml_sycl_op_pool2d(ctx, dst);
} }
static void ggml_sycl_im2col(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { static void ggml_sycl_im2col(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_im2col); ggml_sycl_op_im2col(ctx, dst);
} }
static void ggml_sycl_sum(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { static void ggml_sycl_sum(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_ASSERT(ggml_is_contiguous(dst->src[0])); GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_sum); ggml_sycl_op_sum(ctx, dst);
} }
static void ggml_sycl_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { static void ggml_sycl_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_ASSERT(ggml_is_contiguous(dst->src[0])); GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_sum_rows); ggml_sycl_op_sum_rows(ctx, dst);
} }
static void ggml_sycl_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { static void ggml_sycl_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_ASSERT(ggml_is_contiguous(dst->src[0])); GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_argsort); ggml_sycl_op_argsort(ctx, dst);
} }
static void ggml_sycl_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { static void ggml_sycl_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_ASSERT(ggml_is_contiguous(dst->src[0])); GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_argmax); ggml_sycl_op_argmax(ctx, dst);
} }
@ -3335,7 +3279,7 @@ catch (sycl::exception const &exc) {
std::exit(1); std::exit(1);
} }
static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tensor * dst) { static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tensor * dst) try {
if (!g_sycl_loaded) return false; if (!g_sycl_loaded) return false;
if (dst->src[0] != nullptr && ggml_backend_buffer_is_sycl_split(dst->src[0]->buffer)) { if (dst->src[0] != nullptr && ggml_backend_buffer_is_sycl_split(dst->src[0]->buffer)) {
@ -3528,6 +3472,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
} }
return true; return true;
} catch (sycl::exception & e) {
std::cerr << e.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
std::exit(1);
} }
GGML_API void ggml_backend_sycl_get_device_description(int device, char *description, GGML_API void ggml_backend_sycl_get_device_description(int device, char *description,

View file

@ -82,10 +82,9 @@ static void im2col_sycl(
} }
} }
void ggml_sycl_op_im2col( void ggml_sycl_op_im2col(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, const ggml_tensor * src0 = dst->src[0];
ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd, const ggml_tensor * src1 = dst->src[1];
const queue_ptr &main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F16); GGML_ASSERT(src0->type == GGML_TYPE_F16);
GGML_ASSERT(src1->type == GGML_TYPE_F32); GGML_ASSERT(src1->type == GGML_TYPE_F32);
@ -115,12 +114,8 @@ void ggml_sycl_op_im2col(
const size_t batch_offset = src1->nb[3] / 4; // nb is byte offset, src is type float32 const size_t batch_offset = src1->nb[3] / 4; // nb is byte offset, src is type float32
if (dst->type == GGML_TYPE_F16) { if (dst->type == GGML_TYPE_F16) {
im2col_sycl(src1_dd, (sycl::half *)dst_dd, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, main_stream); im2col_sycl((const float *) src1->data, (sycl::half *)dst->data, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, ctx.stream());
} else { } else {
im2col_sycl(src1_dd, (float *)dst_dd, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, main_stream); im2col_sycl((const float *) src1->data, (float *)dst->data, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, ctx.stream());
} }
GGML_UNUSED(src0);
GGML_UNUSED(src0_dd);
GGML_UNUSED(ctx);
} }

View file

@ -16,8 +16,6 @@
#include "common.hpp" #include "common.hpp"
void ggml_sycl_op_im2col( void ggml_sycl_op_im2col(
ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_backend_sycl_context & ctx, ggml_tensor *dst);
ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd,
const queue_ptr &main_stream);
#endif // GGML_SYCL_IM2COL_HPP #endif // GGML_SYCL_IM2COL_HPP

View file

@ -367,7 +367,7 @@ static void l2_norm_f32_sycl(const float* x, float* dst, const int ncols,
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
block_dims), block_dims),
[=](sycl::nd_item<3> item_ct1) [=](sycl::nd_item<3> item_ct1)
[[intel::reqd_sub_group_size(WARP_SIZE)]] { [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
l2_norm_f32(x, dst, ncols, eps, item_ct1, l2_norm_f32(x, dst, ncols, eps, item_ct1,
nullptr, WARP_SIZE); nullptr, WARP_SIZE);
}); });
@ -389,7 +389,7 @@ static void l2_norm_f32_sycl(const float* x, float* dst, const int ncols,
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
block_dims), block_dims),
[=](sycl::nd_item<3> item_ct1) [=](sycl::nd_item<3> item_ct1)
[[intel::reqd_sub_group_size(WARP_SIZE)]] { [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
l2_norm_f32(x, dst, ncols, eps, item_ct1, l2_norm_f32(x, dst, ncols, eps, item_ct1,
get_pointer(s_sum_acc_ct1), work_group_size); get_pointer(s_sum_acc_ct1), work_group_size);
}); });
@ -397,90 +397,78 @@ static void l2_norm_f32_sycl(const float* x, float* dst, const int ncols,
} }
} }
void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0, const ggml_tensor* src1, void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
ggml_tensor* dst, const float* src0_dd,
const float* src1_dd, float* dst_dd,
const queue_ptr& main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32); GGML_ASSERT(dst->type == GGML_TYPE_F32);
const int64_t ne00 = src0->ne[0]; const int64_t ne00 = dst->src[0]->ne[0];
const int64_t nrows = ggml_nrows(src0); const int64_t nrows = ggml_nrows(dst->src[0]);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
float eps; float eps;
memcpy(&eps, dst->op_params, sizeof(float)); memcpy(&eps, dst->op_params, sizeof(float));
norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device); norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device);
(void)src1;
(void)dst;
(void)src1_dd;
} }
void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0, void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
const ggml_tensor* src1, ggml_tensor* dst,
const float* src0_dd, const float* src1_dd,
float* dst_dd,
const queue_ptr& main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32); GGML_ASSERT(dst->type == GGML_TYPE_F32);
int num_groups = dst->op_params[0]; int num_groups = dst->op_params[0];
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
float eps; float eps;
memcpy(&eps, dst->op_params + 1, sizeof(float)); memcpy(&eps, dst->op_params + 1, sizeof(float));
int group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups); int group_size = dst->src[0]->ne[0] * dst->src[0]->ne[1] * ((dst->src[0]->ne[2] + num_groups - 1) / num_groups);
group_norm_f32_sycl(src0_dd, dst_dd, num_groups, eps, group_size, src0->ne[0] * src0->ne[1] * src0->ne[2], main_stream, ctx.device); group_norm_f32_sycl(src0_dd, dst_dd, num_groups, eps, group_size, dst->src[0]->ne[0] * dst->src[0]->ne[1] * dst->src[0]->ne[2], main_stream, ctx.device);
(void)src1;
(void)dst;
(void)src1_dd;
GGML_UNUSED(ctx);
} }
void ggml_sycl_op_rms_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0, void ggml_sycl_op_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
const ggml_tensor* src1, ggml_tensor* dst,
const float* src0_dd, const float* src1_dd,
float* dst_dd,
const queue_ptr& main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32); GGML_ASSERT(dst->type == GGML_TYPE_F32);
const int64_t ne00 = src0->ne[0]; const int64_t ne00 = dst->src[0]->ne[0];
const int64_t nrows = ggml_nrows(src0); const int64_t nrows = ggml_nrows(dst->src[0]);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
float eps; float eps;
memcpy(&eps, dst->op_params, sizeof(float)); memcpy(&eps, dst->op_params, sizeof(float));
rms_norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device); rms_norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device);
(void)src1;
(void)dst;
(void)src1_dd;
} }
void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0, void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
const ggml_tensor* src1, ggml_tensor* dst,
const float* src0_dd, const float* src1_dd,
float* dst_dd,
const queue_ptr& main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32); GGML_ASSERT(dst->type == GGML_TYPE_F32);
const int64_t ne00 = src0->ne[0]; dpct::queue_ptr main_stream = ctx.stream();
const int64_t nrows = ggml_nrows(src0); SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const int64_t ne00 = dst->src[0]->ne[0];
const int64_t nrows = ggml_nrows(dst->src[0]);
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
float eps; float eps;
memcpy(&eps, dst->op_params, sizeof(float)); memcpy(&eps, dst->op_params, sizeof(float));
l2_norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device); l2_norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device);
(void)src1;
(void)dst;
(void)src1_dd;
} }

View file

@ -15,27 +15,12 @@
#include "common.hpp" #include "common.hpp"
void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0, const ggml_tensor* src1, void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst);
ggml_tensor* dst, const float* src0_dd,
const float* src1_dd, float* dst_dd,
const queue_ptr& main_stream);
void ggml_sycl_op_rms_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0, void ggml_sycl_op_rms_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst);
const ggml_tensor* src1, ggml_tensor* dst,
const float* src0_dd, const float* src1_dd,
float* dst_dd,
const queue_ptr& main_stream);
void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0, void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst);
const ggml_tensor* src1, ggml_tensor* dst,
const float* src0_dd, const float* src1_dd,
float* dst_dd,
const queue_ptr& main_stream);
void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0, void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst);
const ggml_tensor* src1, ggml_tensor* dst,
const float* src0_dd, const float* src1_dd,
float* dst_dd,
const queue_ptr& main_stream);
#endif // GGML_SYCL_NORM_HPP #endif // GGML_SYCL_NORM_HPP

View file

@ -1,8 +1,5 @@
#include <sycl/sycl.hpp>
#include <oneapi/mkl.hpp>
#include "outprod.hpp" #include "outprod.hpp"
void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, ggml_tensor* dst) { void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
const ggml_tensor *src0 = dst->src[0]; const ggml_tensor *src0 = dst->src[0];
const ggml_tensor *src1 = dst->src[1]; const ggml_tensor *src1 = dst->src[1];
@ -34,20 +31,13 @@ void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
// Handle transposition of src1 // Handle transposition of src1
const bool src1_T = ggml_is_transposed(src1); const bool src1_T = ggml_is_transposed(src1);
const oneapi::mkl::transpose src1_op = const oneapi::math::transpose src1_op = src1_T ? oneapi::math::transpose::nontrans : oneapi::math::transpose::trans;
src1_T ? oneapi::mkl::transpose::nontrans : oneapi::mkl::transpose::trans;
const int64_t ldb = (src1_T ? nb10 : nb11) / sizeof(float); const int64_t ldb = (src1_T ? nb10 : nb11) / sizeof(float);
try { try {
// Perform matrix multiplication using oneMKL GEMM // Perform matrix multiplication using oneMath GEMM
#ifdef GGML_SYCL_NVIDIA oneapi::math::blas::column_major::gemm(get_onemath_backend(*stream), oneapi::math::transpose::nontrans, src1_op,
oneapi::mkl::blas::column_major::gemm(oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ *stream }, ne0, ne1, ne01, alpha, src0_d, ne00, src1_d, ldb, beta, dst_d, ne0);
oneapi::mkl::transpose::nontrans, src1_op, ne0, ne1, ne01, alpha, src0_d,
ne00, src1_d, ldb, beta, dst_d, ne0);
#else
oneapi::mkl::blas::column_major::gemm(*stream, oneapi::mkl::transpose::nontrans, src1_op, ne0, ne1, ne01, alpha,
src0_d, ne00, src1_d, ldb, beta, dst_d, ne0);
#endif
} }
catch (sycl::exception const& exc) { catch (sycl::exception const& exc) {
std::cerr << exc.what() << std::endl; std::cerr << exc.what() << std::endl;

View file

@ -192,18 +192,15 @@ static void rope_neox_sycl(
} }
} }
void ggml_sycl_op_rope( void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
const float *src0_dd, const float *src1_dd, float *dst_dd, const queue_ptr &main_stream) {
const ggml_tensor * src2 = dst->src[2];
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
GGML_ASSERT(src0->type == dst->type); GGML_ASSERT(dst->src[0]->type == dst->type);
const int64_t ne00 = src0->ne[0]; const int64_t ne00 = dst->src[0]->ne[0];
const int64_t ne01 = src0->ne[1]; const int64_t ne01 = dst->src[0]->ne[1];
const int64_t nr = ggml_nrows(src0); const int64_t nr = ggml_nrows(dst->src[0]);
//const int n_past = ((int32_t *) dst->op_params)[0]; //const int n_past = ((int32_t *) dst->op_params)[0];
const int n_dims = ((int32_t *) dst->op_params)[1]; const int n_dims = ((int32_t *) dst->op_params)[1];
@ -228,49 +225,47 @@ void ggml_sycl_op_rope(
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
const int32_t * pos = (const int32_t *) src1_dd; const int32_t * pos = (const int32_t *) dst->src[1]->data;
const float * freq_factors = nullptr; const float * freq_factors = nullptr;
if (src2 != nullptr) { if (dst->src[2] != nullptr) {
freq_factors = (const float *) src2->data; freq_factors = (const float *) dst->src[2]->data;
} }
rope_corr_dims corr_dims; rope_corr_dims corr_dims;
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v); ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
// compute // compute
if (is_neox) { if (is_neox) {
if (src0->type == GGML_TYPE_F32) { if (dst->src[0]->type == GGML_TYPE_F32) {
rope_neox_sycl( rope_neox_sycl(
(const float *)src0_dd, (float *)dst_dd, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, (const float *)dst->src[0]->data, (float *)dst->data, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
attn_factor, corr_dims, freq_factors, main_stream attn_factor, corr_dims, freq_factors, main_stream
); );
} else if (src0->type == GGML_TYPE_F16) { } else if (dst->src[0]->type == GGML_TYPE_F16) {
rope_neox_sycl( rope_neox_sycl(
(const sycl::half *)src0_dd, (sycl::half *)dst_dd, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, (const sycl::half *)dst->src[0]->data, (sycl::half *)dst->data, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
attn_factor, corr_dims, freq_factors, main_stream attn_factor, corr_dims, freq_factors, main_stream
); );
} else { } else {
GGML_ABORT("fatal error"); GGML_ABORT("fatal error");
} }
} else { } else {
if (src0->type == GGML_TYPE_F32) { if (dst->src[0]->type == GGML_TYPE_F32) {
rope_norm_sycl( rope_norm_sycl(
(const float *)src0_dd, (float *)dst_dd, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, (const float *)dst->src[0]->data, (float *)dst->data, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
attn_factor, corr_dims, freq_factors, main_stream attn_factor, corr_dims, freq_factors, main_stream
); );
} else if (src0->type == GGML_TYPE_F16) { } else if (dst->src[0]->type == GGML_TYPE_F16) {
rope_norm_sycl( rope_norm_sycl(
(const sycl::half *)src0_dd, (sycl::half *)dst_dd, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, (const sycl::half *)dst->src[0]->data, (sycl::half *)dst->data, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
attn_factor, corr_dims, freq_factors, main_stream attn_factor, corr_dims, freq_factors, main_stream
); );
} else { } else {
GGML_ABORT("fatal error"); GGML_ABORT("fatal error");
} }
} }
GGML_UNUSED(src1);
GGML_UNUSED(dst);
GGML_UNUSED(src1_dd);
GGML_UNUSED(ctx);
} }

View file

@ -15,8 +15,6 @@
#include "common.hpp" #include "common.hpp"
void ggml_sycl_op_rope( void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst);
ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
const float *src0_dd, const float *src1_dd, float *dst_dd, const queue_ptr &main_stream);
#endif // GGML_SYCL_ROPE_HPP #endif // GGML_SYCL_ROPE_HPP

View file

@ -1,6 +1,9 @@
#ifndef GGML_VULKAN_COOPMAT_GLSLC_SUPPORT #ifndef GGML_VULKAN_COOPMAT_GLSLC_SUPPORT
#define GGML_VULKAN_COOPMAT_GLSLC_SUPPORT #define GGML_VULKAN_COOPMAT_GLSLC_SUPPORT
#endif #endif
#ifndef GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT
#define GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT
#endif
#include "ggml-vulkan.h" #include "ggml-vulkan.h"
#include <vulkan/vulkan_core.h> #include <vulkan/vulkan_core.h>
@ -238,6 +241,8 @@ struct vk_device_struct {
bool float_controls_rte_fp16; bool float_controls_rte_fp16;
bool subgroup_add; bool subgroup_add;
bool integer_dot_product;
bool subgroup_size_control; bool subgroup_size_control;
uint32_t subgroup_min_size; uint32_t subgroup_min_size;
uint32_t subgroup_max_size; uint32_t subgroup_max_size;
@ -249,6 +254,12 @@ struct vk_device_struct {
uint32_t coopmat_m; uint32_t coopmat_m;
uint32_t coopmat_n; uint32_t coopmat_n;
uint32_t coopmat_k; uint32_t coopmat_k;
bool coopmat_int_support;
uint32_t coopmat_int_m;
uint32_t coopmat_int_n;
uint32_t coopmat_int_k;
bool coopmat2; bool coopmat2;
size_t idx; size_t idx;
@ -267,10 +278,10 @@ struct vk_device_struct {
vk_matmul_pipeline pipeline_matmul_f32_f16 {}; vk_matmul_pipeline pipeline_matmul_f32_f16 {};
vk_matmul_pipeline2 pipeline_matmul_f16; vk_matmul_pipeline2 pipeline_matmul_f16;
vk_matmul_pipeline2 pipeline_matmul_f16_f32; vk_matmul_pipeline2 pipeline_matmul_f16_f32;
vk_pipeline pipeline_matmul_split_k_reduce;
vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_COUNT];
vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat[GGML_TYPE_COUNT]; vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat[GGML_TYPE_COUNT];
vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_COUNT];
vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_COUNT];
vk_matmul_pipeline pipeline_matmul_id_f32 {}; vk_matmul_pipeline pipeline_matmul_id_f32 {};
vk_matmul_pipeline2 pipeline_matmul_id_f16; vk_matmul_pipeline2 pipeline_matmul_id_f16;
@ -278,6 +289,9 @@ struct vk_device_struct {
vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_id[GGML_TYPE_COUNT]; vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_id[GGML_TYPE_COUNT];
vk_pipeline pipeline_matmul_split_k_reduce;
vk_pipeline pipeline_quantize_q8_1;
vk_pipeline pipeline_dequant[GGML_TYPE_COUNT]; vk_pipeline pipeline_dequant[GGML_TYPE_COUNT];
vk_pipeline pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_COUNT][mul_mat_vec_max_cols]; vk_pipeline pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_COUNT][mul_mat_vec_max_cols];
vk_pipeline pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_COUNT][mul_mat_vec_max_cols]; vk_pipeline pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_COUNT][mul_mat_vec_max_cols];
@ -644,6 +658,13 @@ struct vk_op_rwkv_wkv7_push_constants {
uint32_t H; uint32_t H;
}; };
struct vk_op_upscale_push_constants {
uint32_t ne; uint32_t a_offset; uint32_t d_offset;
uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13;
float sf0; float sf1; float sf2; float sf3;
};
// Allow pre-recording command buffers // Allow pre-recording command buffers
struct vk_staging_memcpy { struct vk_staging_memcpy {
vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {} vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
@ -653,13 +674,6 @@ struct vk_staging_memcpy {
size_t n; size_t n;
}; };
struct vk_op_upscale_push_constants {
uint32_t ne; uint32_t a_offset; uint32_t d_offset;
uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13;
float sf0; float sf1; float sf2; float sf3;
};
struct vk_context_struct { struct vk_context_struct {
vk_submission * s; vk_submission * s;
std::vector<vk_sequence> seqs; std::vector<vk_sequence> seqs;
@ -1602,6 +1616,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
// mulmat // mulmat
std::vector<uint32_t> l_warptile, m_warptile, s_warptile, std::vector<uint32_t> l_warptile, m_warptile, s_warptile,
l_warptile_mmq, m_warptile_mmq, s_warptile_mmq, l_warptile_mmq, m_warptile_mmq, s_warptile_mmq,
l_warptile_mmq_int, m_warptile_mmq_int, s_warptile_mmq_int,
l_warptile_mmq_k, m_warptile_mmq_k, s_warptile_mmq_k, l_warptile_mmq_k, m_warptile_mmq_k, s_warptile_mmq_k,
l_warptile_mmqid, m_warptile_mmqid, s_warptile_mmqid; l_warptile_mmqid, m_warptile_mmqid, s_warptile_mmqid;
std::array<uint32_t, 3> l_wg_denoms, m_wg_denoms, s_wg_denoms, std::array<uint32_t, 3> l_wg_denoms, m_wg_denoms, s_wg_denoms,
@ -1666,6 +1681,20 @@ static void ggml_vk_load_shaders(vk_device& device) {
m_warptile_mmq = { 128, 64, 64, 32, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 }; m_warptile_mmq = { 128, 64, 64, 32, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
s_warptile_mmq = { subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 }; s_warptile_mmq = { subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 };
const uint32_t tm_int_l = device->coopmat_int_support ? device->coopmat_int_m : 4;
const uint32_t tm_int_m = device->coopmat_int_support ? device->coopmat_int_m : 4;
const uint32_t tm_int_s = device->coopmat_int_support ? device->coopmat_int_m : 2;
const uint32_t tn_int_l = device->coopmat_int_support ? device->coopmat_int_n : 4;
const uint32_t tn_int_m = device->coopmat_int_support ? device->coopmat_int_n : 2;
const uint32_t tn_int_s = device->coopmat_int_support ? device->coopmat_int_n : 2;
const uint32_t tk_int_l = device->coopmat_int_support ? device->coopmat_int_k : 1;
const uint32_t tk_int_m = device->coopmat_int_support ? device->coopmat_int_k : 1;
const uint32_t tk_int_s = device->coopmat_int_support ? device->coopmat_int_k : 1;
l_warptile_mmq_int = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, tm_int_l, tn_int_l, tk_int_l, subgroup_size_8 };
m_warptile_mmq_int = { 128, 64, 64, 32, subgroup_size_8, 32, 2, tm_int_m, tn_int_m, tk_int_m, subgroup_size_8 };
s_warptile_mmq_int = { subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_int_s, tn_int_s, tk_int_s, subgroup_size_8 };
l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 }; l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 };
m_mmq_wg_denoms = m_wg_denoms = { 64, 64, 1 }; m_mmq_wg_denoms = m_wg_denoms = { 64, 64, 1 };
s_mmq_wg_denoms = s_wg_denoms = { 32, 32, 1 }; s_mmq_wg_denoms = s_wg_denoms = { 32, 32, 1 };
@ -2004,6 +2033,14 @@ static void ggml_vk_load_shaders(vk_device& device) {
if (device->mul_mat ## ID ## _s[TYPE]) \ if (device->mul_mat ## ID ## _s[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \
#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
if (device->mul_mat ## ID ## _l[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
if (device->mul_mat ## ID ## _m[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
if (device->mul_mat ## ID ## _s[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
// Create 2 variants, {f16,f32} accumulator // Create 2 variants, {f16,f32} accumulator
#define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ #define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
CREATE_MM(TYPE, PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ CREATE_MM(TYPE, PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
@ -2035,6 +2072,16 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc, matmul_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc, matmul_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
if (device->integer_dot_product) {
CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
}
#endif
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
@ -2060,6 +2107,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
#undef CREATE_MM2 #undef CREATE_MM2
#undef CREATE_MMQ
#undef CREATE_MM #undef CREATE_MM
} else { } else {
// Create 6 variants, {s,m,l}x{unaligned,aligned} // Create 6 variants, {s,m,l}x{unaligned,aligned}
@ -2077,6 +2125,14 @@ static void ggml_vk_load_shaders(vk_device& device) {
if (device->mul_mat ## ID ## _s[TYPE]) \ if (device->mul_mat ## ID ## _s[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \
#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
if (device->mul_mat ## ID ## _l[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
if (device->mul_mat ## ID ## _m[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
if (device->mul_mat ## ID ## _s[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
@ -2103,6 +2159,16 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
if (device->integer_dot_product) {
CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
}
#endif
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
@ -2136,7 +2202,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
uint32_t rm_stdq = 1; uint32_t rm_stdq = 1;
uint32_t rm_kq = 2; uint32_t rm_kq = 2;
if (device->vendor_id == VK_VENDOR_ID_AMD) { if (device->vendor_id == VK_VENDOR_ID_AMD) {
if (device->subgroup_min_size == 64 && device->subgroup_max_size == 64) { // GCN if (device->architecture == AMD_GCN) {
rm_stdq = 2; rm_stdq = 2;
rm_kq = 4; rm_kq = 4;
} }
@ -2270,6 +2336,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1, "quantize_q8_1", quantize_q8_1_len, quantize_q8_1_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1);
for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) { for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) {
if (device->subgroup_add && device->subgroup_require_full_support) { if (device->subgroup_add && device->subgroup_require_full_support) {
@ -2456,6 +2523,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
bool pipeline_robustness = false; bool pipeline_robustness = false;
bool coopmat2_support = false; bool coopmat2_support = false;
device->coopmat_support = false; device->coopmat_support = false;
device->integer_dot_product = false;
for (const auto& properties : ext_props) { for (const auto& properties : ext_props) {
if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) { if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
@ -2481,6 +2549,11 @@ static vk_device ggml_vk_get_device(size_t idx) {
} else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 && } else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 &&
!getenv("GGML_VK_DISABLE_COOPMAT2")) { !getenv("GGML_VK_DISABLE_COOPMAT2")) {
coopmat2_support = true; coopmat2_support = true;
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
} else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0 &&
!getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) {
device->integer_dot_product = true;
#endif
} }
} }
@ -2494,6 +2567,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
vk::PhysicalDeviceVulkan11Properties vk11_props; vk::PhysicalDeviceVulkan11Properties vk11_props;
vk::PhysicalDeviceVulkan12Properties vk12_props; vk::PhysicalDeviceVulkan12Properties vk12_props;
vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props; vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props;
vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR shader_integer_dot_product_props;
props2.pNext = &props3; props2.pNext = &props3;
props3.pNext = &subgroup_props; props3.pNext = &subgroup_props;
@ -2528,6 +2602,15 @@ static vk_device ggml_vk_get_device(size_t idx) {
} }
#endif #endif
#if !defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) //prevent crash if we do a non-d4pa build
device->integer_dot_product = false;
#endif
if (device->integer_dot_product) {
last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_props;
last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_props;
}
device->physical_device.getProperties2(&props2); device->physical_device.getProperties2(&props2);
device->properties = props2.properties; device->properties = props2.properties;
device->vendor_id = device->properties.vendorID; device->vendor_id = device->properties.vendorID;
@ -2578,6 +2661,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
device->coopmat_support = false; device->coopmat_support = false;
#endif #endif
device->integer_dot_product = device->integer_dot_product && shader_integer_dot_product_props.integerDotProduct4x8BitPackedSignedAccelerated;
std::vector<vk::QueueFamilyProperties> queue_family_props = device->physical_device.getQueueFamilyProperties(); std::vector<vk::QueueFamilyProperties> queue_family_props = device->physical_device.getQueueFamilyProperties();
// Try to find a non-graphics compute queue and transfer-focused queues // Try to find a non-graphics compute queue and transfer-focused queues
@ -2670,6 +2755,14 @@ static vk_device ggml_vk_get_device(size_t idx) {
device_extensions.push_back("VK_KHR_maintenance4"); device_extensions.push_back("VK_KHR_maintenance4");
} }
VkPhysicalDeviceShaderIntegerDotProductFeaturesKHR shader_integer_dot_product_features {};
shader_integer_dot_product_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_INTEGER_DOT_PRODUCT_FEATURES_KHR;
if (device->integer_dot_product) {
last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_features;
last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_features;
device_extensions.push_back("VK_KHR_shader_integer_dot_product");
}
vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2); vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2);
device->fp16 = device->fp16 && vk12_features.shaderFloat16; device->fp16 = device->fp16 && vk12_features.shaderFloat16;
@ -2839,6 +2932,17 @@ static vk_device ggml_vk_get_device(size_t idx) {
device->coopmat_acc_f16_support = true; device->coopmat_acc_f16_support = true;
} }
} }
} else if ((vk::ComponentTypeKHR)prop.AType == vk::ComponentTypeKHR::eSint8 &&
(vk::ComponentTypeKHR)prop.BType == vk::ComponentTypeKHR::eSint8 &&
(vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eSint32 &&
(vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eSint32 &&
(vk::ScopeKHR)prop.scope == vk::ScopeKHR::eSubgroup &&
device->coopmat_int_m == 0
) {
device->coopmat_int_support = true;
device->coopmat_int_m = prop.MSize;
device->coopmat_int_n = prop.NSize;
device->coopmat_int_k = prop.KSize;
} }
} }
@ -2943,25 +3047,11 @@ static void ggml_vk_print_gpu_info(size_t idx) {
vk::PhysicalDevice physical_device = devices[dev_num]; vk::PhysicalDevice physical_device = devices[dev_num];
std::vector<vk::ExtensionProperties> ext_props = physical_device.enumerateDeviceExtensionProperties(); std::vector<vk::ExtensionProperties> ext_props = physical_device.enumerateDeviceExtensionProperties();
vk::PhysicalDeviceProperties2 props2;
vk::PhysicalDeviceMaintenance3Properties props3;
vk::PhysicalDeviceSubgroupProperties subgroup_props;
vk::PhysicalDeviceDriverProperties driver_props;
props2.pNext = &props3;
props3.pNext = &subgroup_props;
subgroup_props.pNext = &driver_props;
physical_device.getProperties2(&props2);
vk_device_architecture arch = get_device_architecture(physical_device);
uint32_t default_subgroup_size = get_subgroup_size("", arch);
const size_t subgroup_size = (default_subgroup_size != 0) ? default_subgroup_size : subgroup_props.subgroupSize;
const bool uma = props2.properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
bool fp16_storage = false; bool fp16_storage = false;
bool fp16_compute = false; bool fp16_compute = false;
bool coopmat_support = false; bool coopmat_support = false;
bool coopmat2_support = false; bool coopmat2_support = false;
bool integer_dot_product = false;
for (auto properties : ext_props) { for (auto properties : ext_props) {
if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) { if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) {
@ -2977,27 +3067,44 @@ static void ggml_vk_print_gpu_info(size_t idx) {
} else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 && } else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 &&
!getenv("GGML_VK_DISABLE_COOPMAT2")) { !getenv("GGML_VK_DISABLE_COOPMAT2")) {
coopmat2_support = true; coopmat2_support = true;
#endif
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
} else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0 &&
!getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) {
integer_dot_product = true;
#endif #endif
} }
} }
const vk_device_architecture device_architecture = get_device_architecture(physical_device); const vk_device_architecture device_architecture = get_device_architecture(physical_device);
if (!ggml_vk_khr_cooperative_matrix_support(props2.properties, driver_props, device_architecture)) {
coopmat_support = false;
}
const char* GGML_VK_DISABLE_F16 = getenv("GGML_VK_DISABLE_F16"); const char* GGML_VK_DISABLE_F16 = getenv("GGML_VK_DISABLE_F16");
bool force_disable_f16 = GGML_VK_DISABLE_F16 != nullptr; bool force_disable_f16 = GGML_VK_DISABLE_F16 != nullptr;
bool fp16 = !force_disable_f16 && fp16_storage && fp16_compute; bool fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
vk::PhysicalDeviceFeatures device_features = physical_device.getFeatures(); vk::PhysicalDeviceProperties2 props2;
vk::PhysicalDeviceMaintenance3Properties props3;
vk::PhysicalDeviceSubgroupProperties subgroup_props;
vk::PhysicalDeviceDriverProperties driver_props;
vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR shader_integer_dot_product_props;
props2.pNext = &props3;
props3.pNext = &subgroup_props;
subgroup_props.pNext = &driver_props;
// Pointer to the last chain element
VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&driver_props;
if (integer_dot_product) {
last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_props;
last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_props;
}
physical_device.getProperties2(&props2);
VkPhysicalDeviceFeatures2 device_features2; VkPhysicalDeviceFeatures2 device_features2;
device_features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2; device_features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2;
device_features2.pNext = nullptr; device_features2.pNext = nullptr;
device_features2.features = (VkPhysicalDeviceFeatures)device_features;
VkPhysicalDeviceVulkan11Features vk11_features; VkPhysicalDeviceVulkan11Features vk11_features;
vk11_features.pNext = nullptr; vk11_features.pNext = nullptr;
@ -3010,7 +3117,7 @@ static void ggml_vk_print_gpu_info(size_t idx) {
vk11_features.pNext = &vk12_features; vk11_features.pNext = &vk12_features;
// Pointer to the last chain element // Pointer to the last chain element
VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&vk12_features; last_struct = (VkBaseOutStructure *)&vk12_features;
#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) #if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
VkPhysicalDeviceCooperativeMatrixFeaturesKHR coopmat_features; VkPhysicalDeviceCooperativeMatrixFeaturesKHR coopmat_features;
@ -3022,20 +3129,37 @@ static void ggml_vk_print_gpu_info(size_t idx) {
last_struct->pNext = (VkBaseOutStructure *)&coopmat_features; last_struct->pNext = (VkBaseOutStructure *)&coopmat_features;
last_struct = (VkBaseOutStructure *)&coopmat_features; last_struct = (VkBaseOutStructure *)&coopmat_features;
} }
#endif
VkPhysicalDeviceShaderIntegerDotProductFeaturesKHR shader_integer_dot_product_features {};
shader_integer_dot_product_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_INTEGER_DOT_PRODUCT_FEATURES_KHR;
if (integer_dot_product) {
last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_features;
last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_features;
}
vkGetPhysicalDeviceFeatures2(physical_device, &device_features2); vkGetPhysicalDeviceFeatures2(physical_device, &device_features2);
fp16 = fp16 && vk12_features.shaderFloat16; fp16 = fp16 && vk12_features.shaderFloat16;
coopmat_support = coopmat_support && coopmat_features.cooperativeMatrix; uint32_t default_subgroup_size = get_subgroup_size("", device_architecture);
#endif const size_t subgroup_size = (default_subgroup_size != 0) ? default_subgroup_size : subgroup_props.subgroupSize;
const bool uma = props2.properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
integer_dot_product = integer_dot_product
&& shader_integer_dot_product_props.integerDotProduct4x8BitPackedSignedAccelerated
&& shader_integer_dot_product_features.shaderIntegerDotProduct;
coopmat_support = coopmat_support
&& coopmat_features.cooperativeMatrix
&& ggml_vk_khr_cooperative_matrix_support(props2.properties, driver_props, device_architecture);
std::string matrix_cores = coopmat2_support ? "NV_coopmat2" : coopmat_support ? "KHR_coopmat" : "none"; std::string matrix_cores = coopmat2_support ? "NV_coopmat2" : coopmat_support ? "KHR_coopmat" : "none";
std::string device_name = props2.properties.deviceName.data(); std::string device_name = props2.properties.deviceName.data();
GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | warp size: %zu | shared memory: %d | matrix cores: %s\n", GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | warp size: %zu | shared memory: %d | int dot: %d | matrix cores: %s\n",
idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, subgroup_size, idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, subgroup_size,
props2.properties.limits.maxComputeSharedMemorySize, matrix_cores.c_str()); props2.properties.limits.maxComputeSharedMemorySize, integer_dot_product, matrix_cores.c_str());
if (props2.properties.deviceType == vk::PhysicalDeviceType::eCpu) { if (props2.properties.deviceType == vk::PhysicalDeviceType::eCpu) {
GGML_LOG_DEBUG("ggml_vulkan: Warning: Device type is CPU. This is probably not the device you want.\n"); GGML_LOG_DEBUG("ggml_vulkan: Warning: Device type is CPU. This is probably not the device you want.\n");
@ -3301,6 +3425,17 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
} }
} }
// MMQ
if (src1_type == GGML_TYPE_Q8_1) {
vk_matmul_pipeline pipelines = ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f16acc;
if (pipelines->s == nullptr && pipelines->m == nullptr && pipelines->l == nullptr) {
return nullptr;
}
return pipelines;
}
if (src1_type != GGML_TYPE_F32 && !ctx->device->coopmat2) { if (src1_type != GGML_TYPE_F32 && !ctx->device->coopmat2) {
return nullptr; return nullptr;
} }
@ -3593,8 +3728,6 @@ static vk_submission ggml_vk_begin_submission(vk_device& device, vk_queue& q, bo
return s; return s;
} }
static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context& subctx, vk_pipeline& pipeline, std::initializer_list<vk::DescriptorBufferInfo> const& descriptor_buffer_infos, size_t push_constant_size, const void* push_constants, std::array<uint32_t, 3> elements) { static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context& subctx, vk_pipeline& pipeline, std::initializer_list<vk::DescriptorBufferInfo> const& descriptor_buffer_infos, size_t push_constant_size, const void* push_constants, std::array<uint32_t, 3> elements) {
const uint32_t wg0 = CEIL_DIV(elements[0], pipeline->wg_denoms[0]); const uint32_t wg0 = CEIL_DIV(elements[0], pipeline->wg_denoms[0]);
const uint32_t wg1 = CEIL_DIV(elements[1], pipeline->wg_denoms[1]); const uint32_t wg1 = CEIL_DIV(elements[1], pipeline->wg_denoms[1]);
@ -4024,8 +4157,8 @@ static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, int m, int
return split_k; return split_k;
} }
static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned, ggml_type src0_type) { static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type, ggml_type src1_type) {
VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")"); VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")");
if (ctx->device->coopmat2) { if (ctx->device->coopmat2) {
// Use large shader when the N dimension is greater than the medium shader's tile size // Use large shader when the N dimension is greater than the medium shader's tile size
@ -4050,9 +4183,9 @@ static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx,
return aligned ? mmp->a_l : mmp->l; return aligned ? mmp->a_l : mmp->l;
} }
static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type) { static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type, ggml_type src1_type) {
VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ")"); VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")");
return ggml_vk_guess_matmul_pipeline(ctx, mmp, m, n, true, src0_type)->align; return ggml_vk_guess_matmul_pipeline(ctx, mmp, m, n, true, src0_type, src1_type)->align;
} }
static void ggml_vk_matmul( static void ggml_vk_matmul(
@ -4062,7 +4195,7 @@ static void ggml_vk_matmul(
uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d, uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d,
uint32_t split_k, uint32_t batch, uint32_t ne02, uint32_t ne12, uint32_t broadcast2, uint32_t broadcast3, uint32_t split_k, uint32_t batch, uint32_t ne02, uint32_t ne12, uint32_t broadcast2, uint32_t broadcast3,
uint32_t padded_n) { uint32_t padded_n) {
VK_LOG_DEBUG("ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ")"); VK_LOG_DEBUG("ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ", padded_n: " << padded_n << ")");
ggml_vk_sync_buffers(subctx); ggml_vk_sync_buffers(subctx);
if (split_k == 1) { if (split_k == 1) {
const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3, padded_n }; const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3, padded_n };
@ -4080,7 +4213,7 @@ static void ggml_vk_matmul(
ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2.size() * sizeof(uint32_t), pc2.data(), { m * n * batch, 1, 1 }); ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2.size() * sizeof(uint32_t), pc2.data(), { m * n * batch, 1, 1 });
} }
static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned, ggml_type src0_type) { static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type) {
VK_LOG_DEBUG("ggml_vk_guess_matmul_id_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")"); VK_LOG_DEBUG("ggml_vk_guess_matmul_id_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")");
if (ctx->device->coopmat2) { if (ctx->device->coopmat2) {
@ -4222,6 +4355,25 @@ static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context&
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, sizeof(vk_op_unary_push_constants), &pc, elements); ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, sizeof(vk_op_unary_push_constants), &pc, elements);
} }
static vk_pipeline ggml_vk_get_quantize_pipeline(ggml_backend_vk_context * ctx, ggml_type type) {
switch(type) {
case GGML_TYPE_Q8_1:
return ctx->device->pipeline_quantize_q8_1;
default:
std::cerr << "Missing quantize pipeline for type: " << ggml_type_name(type) << std::endl;
GGML_ABORT("fatal error");
}
}
static void ggml_vk_quantize_q8_1(ggml_backend_vk_context * ctx, vk_context& subctx, vk_subbuffer&& in, vk_subbuffer&& out, uint32_t ne) {
VK_LOG_DEBUG("ggml_vk_quantize_q8_1(" << "buffer in size=" << in.buffer->size << ", buffer out size=" << out.buffer->size << ", " << ne << ")");
vk_pipeline pipeline = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
ggml_vk_sync_buffers(subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, sizeof(uint32_t), &ne, { ne, 1, 1 });
}
static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
VK_LOG_DEBUG("ggml_vk_mul_mat_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; VK_LOG_DEBUG("ggml_vk_mul_mat_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
@ -4273,10 +4425,19 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig; const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type, (ggml_prec)dst->op_params[0]); bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && (ne11 * ne10) % 4 == 0;
// Check for mmq first
vk_matmul_pipeline mmp = quantize_y ? ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, GGML_TYPE_Q8_1, (ggml_prec)dst->op_params[0]) : nullptr;
if (mmp == nullptr) {
// Fall back to f16 dequant mul mat
mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type, (ggml_prec)dst->op_params[0]);
quantize_y = false;
}
const bool qx_needs_dequant = mmp == nullptr || x_non_contig; const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig; const bool qy_needs_dequant = !quantize_y && ((src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig);
if (qx_needs_dequant) { if (qx_needs_dequant) {
// Fall back to dequant + f16 mulmat // Fall back to dequant + f16 mulmat
@ -4286,13 +4447,13 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
// Not implemented // Not implemented
GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? GGML_TYPE_F16 : src0->type)); const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? GGML_TYPE_F16 : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type)));
const bool aligned = ne10 == kpad && ne01 > 8 && ne11 > 8; const bool aligned = !quantize_y && ne10 == kpad && ne01 > 8 && ne11 > 8;
vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? GGML_TYPE_F16 : src0->type); vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? GGML_TYPE_F16 : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type));
// Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking // Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking
uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) :ne11; uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) : ne11;
const int x_ne = ne01 * ne00; const int x_ne = ne01 * ne00;
const int y_ne = padded_n * ne10; const int y_ne = padded_n * ne10;
const int d_ne = ne11 * ne01; const int d_ne = ne11 * ne01;
@ -4302,11 +4463,12 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type); const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type); const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne; const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne;
const uint64_t y_sz = y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne; const uint64_t y_sz = quantize_y ? (y_ne * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) : (y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne);
const uint64_t d_sz = sizeof(float) * d_ne; const uint64_t d_sz = sizeof(float) * d_ne;
vk_pipeline to_fp16_vk_0 = nullptr; vk_pipeline to_fp16_vk_0 = nullptr;
vk_pipeline to_fp16_vk_1 = nullptr; vk_pipeline to_fp16_vk_1 = nullptr;
vk_pipeline to_q8_1 = nullptr;
if (x_non_contig) { if (x_non_contig) {
to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, GGML_TYPE_F16); to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, GGML_TYPE_F16);
@ -4321,6 +4483,10 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT
GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT
if (quantize_y) {
to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
}
if (dryrun) { if (dryrun) {
const uint64_t x_sz_upd = x_sz * ne02 * ne03; const uint64_t x_sz_upd = x_sz * ne02 * ne03;
const uint64_t y_sz_upd = y_sz * ne12 * ne13; const uint64_t y_sz_upd = y_sz * ne12 * ne13;
@ -4334,7 +4500,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) { if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) {
ctx->prealloc_size_x = x_sz_upd; ctx->prealloc_size_x = x_sz_upd;
} }
if (qy_needs_dequant && ctx->prealloc_size_y < y_sz_upd) { if ((qy_needs_dequant || quantize_y) && ctx->prealloc_size_y < y_sz_upd) {
ctx->prealloc_size_y = y_sz_upd; ctx->prealloc_size_y = y_sz_upd;
} }
if (split_k > 1 && ctx->prealloc_size_split_k < split_k_size) { if (split_k > 1 && ctx->prealloc_size_split_k < split_k_size) {
@ -4349,6 +4515,9 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
if (qy_needs_dequant) { if (qy_needs_dequant) {
ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_1, 1); ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_1, 1);
} }
if (quantize_y) {
ggml_pipeline_request_descriptor_sets(ctx->device, to_q8_1, 1);
}
if (split_k > 1) { if (split_k > 1) {
ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_matmul_split_k_reduce, 1); ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_matmul_split_k_reduce, 1);
} }
@ -4384,6 +4553,9 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
if (qy_needs_dequant) { if (qy_needs_dequant) {
d_Y = ctx->prealloc_y; d_Y = ctx->prealloc_y;
GGML_ASSERT(d_Y->size >= y_sz * ne12 * ne13); GGML_ASSERT(d_Y->size >= y_sz * ne12 * ne13);
} else if (quantize_y) {
d_Y = ctx->prealloc_y;
GGML_ASSERT(d_Y->size >= y_ne * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1));
} else { } else {
d_Y = d_Qy; d_Y = d_Qy;
y_buf_offset = qy_buf_offset; y_buf_offset = qy_buf_offset;
@ -4400,6 +4572,9 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
if (y_non_contig) { if (y_non_contig) {
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }); ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
} }
if (quantize_y) {
ggml_vk_quantize_q8_1(ctx, subctx, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }, y_ne * ne12 * ne13);
}
uint32_t stride_batch_x = ne00*ne01; uint32_t stride_batch_x = ne00*ne01;
uint32_t stride_batch_y = ne10*ne11; uint32_t stride_batch_y = ne10*ne11;
@ -4408,7 +4583,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
stride_batch_x = src0->nb[0] / ggml_type_size(src0->type); stride_batch_x = src0->nb[0] / ggml_type_size(src0->type);
} }
if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) { if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant && !quantize_y) {
stride_batch_y = src1->nb[0] / ggml_type_size(src1->type); stride_batch_y = src1->nb[0] / ggml_type_size(src1->type);
} }
@ -6937,6 +7112,10 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
} }
} }
if (ctx->device->need_compiles) {
ggml_vk_load_shaders(ctx->device);
}
ggml_pipeline_allocate_descriptor_sets(ctx->device); ggml_pipeline_allocate_descriptor_sets(ctx->device);
vk_buffer d_X = ggml_vk_create_buffer_check(ctx->device, sizeof(X_TYPE) * x_ne, vk::MemoryPropertyFlagBits::eDeviceLocal); vk_buffer d_X = ggml_vk_create_buffer_check(ctx->device, sizeof(X_TYPE) * x_ne, vk::MemoryPropertyFlagBits::eDeviceLocal);
@ -7185,6 +7364,10 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_
ggml_pipeline_request_descriptor_sets(ctx->device, p, 1); ggml_pipeline_request_descriptor_sets(ctx->device, p, 1);
if (ctx->device->need_compiles) {
ggml_vk_load_shaders(ctx->device);
}
ggml_pipeline_allocate_descriptor_sets(ctx->device); ggml_pipeline_allocate_descriptor_sets(ctx->device);
ggml_vk_buffer_write(qx_buf, 0, qx, qx_sz); ggml_vk_buffer_write(qx_buf, 0, qx, qx_sz);
@ -7244,66 +7427,198 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_
free(x_chk); free(x_chk);
} }
static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, size_t n, size_t k, size_t batch, size_t num_it, size_t split_k, size_t shader_size, ggml_type quant) { // This does not work without ggml q8_1 quantization support
//
// typedef uint16_t ggml_half;
// typedef uint32_t ggml_half2;
//
// #define QK8_1 32
// typedef struct {
// union {
// struct {
// ggml_half d; // delta
// ggml_half s; // d * sum(qs[i])
// } GGML_COMMON_AGGR_S;
// ggml_half2 ds;
// } GGML_COMMON_AGGR_U;
// int8_t qs[QK8_1]; // quants
// } block_q8_1;
//
// static void ggml_vk_test_quantize(ggml_backend_vk_context * ctx, size_t ne, ggml_type quant) {
// VK_LOG_DEBUG("ggml_vk_test_quantize(" << ne << ")");
// GGML_ASSERT(quant == GGML_TYPE_Q8_1);
//
// const size_t x_sz = sizeof(float) * ne;
// const size_t qx_sz = ne * ggml_type_size(quant)/ggml_blck_size(quant);
// float * x = (float *) malloc(x_sz);
// block_q8_1 * qx = (block_q8_1 *)malloc(qx_sz);
// block_q8_1 * qx_res = (block_q8_1 *)malloc(qx_sz);
// vk_buffer x_buf = ggml_vk_create_buffer_check(ctx->device, x_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
// vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
//
// for (size_t i = 0; i < ne; i++) {
// x[i] = rand() / (float)RAND_MAX;
// }
//
// vk_pipeline p = ggml_vk_get_quantize_pipeline(ctx, quant);
//
// ggml_pipeline_request_descriptor_sets(ctx->device, p, 1);
//
// if (ctx->device->need_compiles) {
// ggml_vk_load_shaders(ctx->device);
// }
//
// ggml_pipeline_allocate_descriptor_sets(ctx->device);
//
// ggml_vk_buffer_write(x_buf, 0, x, x_sz);
//
// vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
// ggml_vk_ctx_begin(ctx->device, subctx);
// ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(x_buf), ggml_vk_subbuffer(qx_buf), ne);
// ggml_vk_ctx_end(subctx);
//
// auto begin = std::chrono::high_resolution_clock::now();
//
// ggml_vk_submit(subctx, ctx->fence);
// VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_quantize waitForFences");
// ctx->device->device.resetFences({ ctx->fence });
//
// auto end = std::chrono::high_resolution_clock::now();
//
// double ms_quant = std::chrono::duration_cast<std::chrono::microseconds>(end-begin).count() / 1000.0;
// ggml_vk_buffer_read(qx_buf, 0, qx, qx_sz);
//
// ggml_vk_quantize_data(x, qx_res, ne, quant);
//
// int first_err = -1;
//
// for (size_t i = 0; i < ne / 32; i++) {
// double error = std::fabs(ggml_fp16_to_fp32(qx_res[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d) - ggml_fp16_to_fp32(qx[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d));
//
// if (first_err < 0 && error > 0.1) {
// first_err = i;
// }
//
// error = std::fabs(ggml_fp16_to_fp32(qx_res[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s) - ggml_fp16_to_fp32(qx[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s));
//
// if (first_err < 0 && error > 0.1) {
// first_err = i;
// }
//
// for (size_t j = 0; j < 32; j++) {
// uint64_t error = std::abs(qx_res[i].qs[j] - qx[i].qs[j]);
//
// if (first_err < 0 && error > 1) {
// first_err = i;
// }
// }
// }
//
// std::cerr << "TEST QUANTIZE " << ggml_type_name(quant) << " time=" << ms_quant << "ms " << (first_err == -1 ? "CORRECT" : "INCORRECT") << std::endl;
//
// if (first_err != -1) {
// std::cerr << "first_error = " << first_err << std::endl;
// std::cerr << "Actual result: " << std::endl << std::endl;
// std::cout << "d=" << ggml_fp16_to_fp32(qx[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d) << " s=" << ggml_fp16_to_fp32(qx[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s) << " ";
// for (size_t j = 0; j < 32; j++) {
// std::cout << " qs" << j << "=" << (uint32_t)qx[first_err].qs[j] << " ";
// }
// std::cerr << std::endl << std::endl << "Expected result: " << std::endl << std::endl;
// std::cout << "d=" << ggml_fp16_to_fp32(qx_res[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d) << " s=" << ggml_fp16_to_fp32(qx_res[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s) << " ";
// for (size_t j = 0; j < 32; j++) {
// std::cout << " qs" << j << "=" << (uint32_t)qx_res[first_err].qs[j] << " ";
// }
// std::cerr << std::endl;
// }
//
// ggml_vk_destroy_buffer(x_buf);
// ggml_vk_destroy_buffer(qx_buf);
//
// free(x);
// free(qx);
// free(qx_res);
// }
static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, size_t n, size_t k, size_t batch, size_t num_it, size_t split_k, size_t shader_size, ggml_type quant, bool mmq = false) {
VK_LOG_DEBUG("ggml_vk_test_dequant_matmul(" << m << ", " << n << ", " << k << ", " << batch << ", " << num_it << ", " << split_k << ", " << ggml_type_name(quant) << ")"); VK_LOG_DEBUG("ggml_vk_test_dequant_matmul(" << m << ", " << n << ", " << k << ", " << batch << ", " << num_it << ", " << split_k << ", " << ggml_type_name(quant) << ")");
const size_t x_ne = m * k * batch; const size_t x_ne = m * k * batch;
const size_t y_ne = k * n * batch; const size_t y_ne = k * n * batch;
const size_t d_ne = m * n * batch; const size_t d_ne = m * n * batch;
vk_matmul_pipeline2 * pipelines;
if (mmq) {
pipelines = ctx->device->pipeline_dequant_mul_mat_mat_q8_1;
} else {
pipelines = ctx->device->pipeline_dequant_mul_mat_mat;
}
const bool fp16acc = ctx->device->fp16;
vk_pipeline p; vk_pipeline p;
std::string shname; std::string shname;
if (shader_size == 0) { if (shader_size == 0) {
p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->a_s : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_s; p = fp16acc ? pipelines[quant].f16acc->a_s : pipelines[quant].f32acc->a_s;
shname = std::string(ggml_type_name(quant)) + "_ALIGNED_S"; shname = std::string(ggml_type_name(quant)) + "_ALIGNED_S";
} else if (shader_size == 1) { } else if (shader_size == 1) {
p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->a_m : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_m; p = fp16acc ? pipelines[quant].f16acc->a_m : pipelines[quant].f32acc->a_m;
shname = std::string(ggml_type_name(quant)) + "_ALIGNED_M"; shname = std::string(ggml_type_name(quant)) + "_ALIGNED_M";
} else if (shader_size == 2) { } else if (shader_size == 2) {
p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->a_l : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_l; p = fp16acc ? pipelines[quant].f16acc->a_l : pipelines[quant].f32acc->a_l;
shname = std::string(ggml_type_name(quant)) + "_ALIGNED_L"; shname = std::string(ggml_type_name(quant)) + "_ALIGNED_L";
} else { } else {
GGML_ASSERT(0); GGML_ASSERT(0);
} }
const size_t kpad = ggml_vk_align_size(k, p->align); const size_t kpad = mmq ? 0 : ggml_vk_align_size(k, p->align);
if (k != kpad) { if (mmq || k != kpad) {
if (shader_size == 0) { if (shader_size == 0) {
p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->s : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->s; p = fp16acc ? pipelines[quant].f16acc->s : pipelines[quant].f32acc->s;
shname = std::string(ggml_type_name(quant)) + "_S"; shname = std::string(ggml_type_name(quant)) + "_S";
} else if (shader_size == 1) { } else if (shader_size == 1) {
p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->m : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->m; p = fp16acc ? pipelines[quant].f16acc->m : pipelines[quant].f32acc->m;
shname = std::string(ggml_type_name(quant)) + "_M"; shname = std::string(ggml_type_name(quant)) + "_M";
} else if (shader_size == 2) { } else if (shader_size == 2) {
p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->l : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->l; p = fp16acc ? pipelines[quant].f16acc->l : pipelines[quant].f32acc->l;
shname = std::string(ggml_type_name(quant)) + "_L"; shname = std::string(ggml_type_name(quant)) + "_L";
} else { } else {
GGML_ASSERT(0); GGML_ASSERT(0);
} }
} }
if (p == nullptr) {
std::cerr << "error: no pipeline for ggml_vk_test_dequant_matmul " << ggml_type_name(quant) << std::endl;
return;
}
const size_t x_sz = sizeof(float) * x_ne; const size_t x_sz = sizeof(float) * x_ne;
const size_t y_sz = sizeof(float) * y_ne; const size_t y_sz = sizeof(float) * y_ne;
const size_t qx_sz = x_ne * ggml_type_size(quant)/ggml_blck_size(quant); const size_t qx_sz = x_ne * ggml_type_size(quant)/ggml_blck_size(quant);
const size_t qy_sz = mmq ? y_ne * ggml_type_size(GGML_TYPE_Q8_1)/ggml_blck_size(GGML_TYPE_Q8_1) : y_sz;
const size_t d_sz = sizeof(float) * d_ne; const size_t d_sz = sizeof(float) * d_ne;
float * x = (float *) malloc(x_sz); float * x = (float *) malloc(x_sz);
float * y = (float *) malloc(y_sz); float * y = (float *) malloc(y_sz);
void * qx = malloc(qx_sz); void * qx = malloc(qx_sz);
vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, vk::MemoryPropertyFlagBits::eDeviceLocal); vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
vk_buffer y_buf = ggml_vk_create_buffer_check(ctx->device, y_sz, vk::MemoryPropertyFlagBits::eDeviceLocal); vk_buffer y_buf = ggml_vk_create_buffer_check(ctx->device, y_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
vk_buffer qy_buf = ggml_vk_create_buffer_check(ctx->device, qy_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
vk_buffer d_buf = ggml_vk_create_buffer_check(ctx->device, d_sz, vk::MemoryPropertyFlagBits::eDeviceLocal); vk_buffer d_buf = ggml_vk_create_buffer_check(ctx->device, d_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
float * d = (float *) malloc(d_sz); float * d = (float *) malloc(d_sz);
float * d_chk = (float *) malloc(d_sz); float * d_chk = (float *) malloc(d_sz);
for (size_t i = 0; i < x_ne; i++) { for (size_t i = 0; i < x_ne; i++) {
x[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f; x[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
// x[i] = (i % k == i / k) ? 1.0f : 0.0f;
// x[i] = i % k;
} }
ggml_vk_quantize_data(x, qx, x_ne, quant); ggml_vk_quantize_data(x, qx, x_ne, quant);
for (size_t i = 0; i < y_ne; i++) { for (size_t i = 0; i < y_ne; i++) {
// y[i] = rand() / (float)RAND_MAX; y[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
y[i] = (i % k == i / k) ? 1.0f : 0.0f; // y[i] = (i % k == i / k) ? 1.0f : 0.0f;
// y[i] = i % k;
} }
ggml_pipeline_request_descriptor_sets(ctx->device, p, num_it); ggml_pipeline_request_descriptor_sets(ctx->device, p, num_it);
@ -7318,6 +7633,13 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
ctx->prealloc_split_k = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne * split_k, vk::MemoryPropertyFlagBits::eDeviceLocal); ctx->prealloc_split_k = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne * split_k, vk::MemoryPropertyFlagBits::eDeviceLocal);
} }
} }
if (mmq) {
ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_quantize_q8_1, num_it);
}
if (ctx->device->need_compiles) {
ggml_vk_load_shaders(ctx->device);
}
ggml_pipeline_allocate_descriptor_sets(ctx->device); ggml_pipeline_allocate_descriptor_sets(ctx->device);
@ -7326,13 +7648,25 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue); vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
ggml_vk_ctx_begin(ctx->device, subctx); ggml_vk_ctx_begin(ctx->device, subctx);
for (size_t i = 0; i < num_it; i++) { if (mmq) {
ggml_vk_matmul( for (size_t i = 0; i < num_it; i++) {
ctx, subctx, p, ggml_vk_subbuffer(qx_buf), ggml_vk_subbuffer(y_buf), ggml_vk_subbuffer(d_buf), ggml_vk_subbuffer(ctx->prealloc_split_k), ggml_vk_quantize_q8_1(ctx, subctx, { y_buf, 0, y_sz }, { qy_buf, 0, qy_sz }, y_ne);
m, n, k, ggml_vk_matmul(
k, k, m, k*m, k*n, m*n, ctx, subctx, p, { qx_buf, 0, qx_sz }, { qy_buf, 0, qy_sz }, { d_buf, 0, d_sz }, { ctx->prealloc_split_k, 0, ctx->prealloc_size_split_k },
split_k, batch, batch, batch, 1, 1, n m, n, k,
); k, k, m, k*m, k*n, m*n,
split_k, batch, batch, batch, 1, 1, n
);
}
} else {
for (size_t i = 0; i < num_it; i++) {
ggml_vk_matmul(
ctx, subctx, p, { qx_buf, 0, qx_sz }, { y_buf, 0, y_sz }, { d_buf, 0, d_sz }, { ctx->prealloc_split_k, 0, ctx->prealloc_size_split_k },
m, n, k,
k, k, m, k*m, k*n, m*n,
split_k, batch, batch, batch, 1, 1, n
);
}
} }
ggml_vk_ctx_end(subctx); ggml_vk_ctx_end(subctx);
@ -7390,7 +7724,11 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
double tflops = 2.0*m*n*k*batch*num_it / (time_ms / 1000.0) / (1000.0*1000.0*1000.0*1000.0); double tflops = 2.0*m*n*k*batch*num_it / (time_ms / 1000.0) / (1000.0*1000.0*1000.0*1000.0);
std::cerr << "TEST MMQ " << shname << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time_ms / num_it << "ms " << tflops << " TFLOPS avg_err=" << avg_err << std::endl; std::cerr << "TEST dequant matmul " << shname;
if (mmq) {
std::cerr << " mmq";
}
std::cerr << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time_ms / num_it << "ms " << tflops << " TFLOPS avg_err=" << avg_err << std::endl;
if (avg_err > 0.01 || std::isnan(avg_err)) { if (avg_err > 0.01 || std::isnan(avg_err)) {
std::cerr << "m = " << first_err_m << " n = " << first_err_n << " b = " << first_err_b << std::endl; std::cerr << "m = " << first_err_m << " n = " << first_err_n << " b = " << first_err_b << std::endl;
@ -7400,6 +7738,12 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
std::cerr << "Expected result: " << std::endl << std::endl; std::cerr << "Expected result: " << std::endl << std::endl;
ggml_vk_print_matrix_area(d_chk, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); ggml_vk_print_matrix_area(d_chk, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
std::cerr << "src0: " << std::endl << std::endl;
ggml_vk_print_matrix_area(x, GGML_TYPE_F32, k, m, first_err_m, first_err_n, first_err_b);
std::cerr << std::endl;
std::cerr << "src1: " << std::endl << std::endl;
ggml_vk_print_matrix_area(y, GGML_TYPE_F32, k, n, first_err_m, first_err_n, first_err_b);
if (split_k > 1) { if (split_k > 1) {
float * split_k_buf = (float *) malloc(sizeof(float) * d_ne * split_k); float * split_k_buf = (float *) malloc(sizeof(float) * d_ne * split_k);
ggml_vk_buffer_read(ctx->prealloc_split_k, 0, split_k_buf, sizeof(float) * d_ne * split_k); ggml_vk_buffer_read(ctx->prealloc_split_k, 0, split_k_buf, sizeof(float) * d_ne * split_k);
@ -7422,6 +7766,7 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
ggml_vk_destroy_buffer(qx_buf); ggml_vk_destroy_buffer(qx_buf);
ggml_vk_destroy_buffer(y_buf); ggml_vk_destroy_buffer(y_buf);
ggml_vk_destroy_buffer(qy_buf);
ggml_vk_destroy_buffer(d_buf); ggml_vk_destroy_buffer(d_buf);
free(x); free(x);
@ -7454,7 +7799,25 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
128, 49, 49, 128, 49, 49,
4096, 49, 4096, 4096, 49, 4096,
}; };
const size_t num_it = 100; const size_t num_it = 1;
ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q4_0);
ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q4_0);
ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q4_0);
ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q4_0, true);
ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q4_0, true);
ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q4_0, true);
ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q8_0);
ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q8_0);
ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q8_0);
ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q8_0, true);
ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q8_0, true);
ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q8_0, true);
abort();
for (size_t i = 0; i < vals.size(); i += 3) { for (size_t i = 0; i < vals.size(); i += 3) {
ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0); ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0);
@ -9266,7 +9629,7 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
} }
if (tensor->op == GGML_OP_FLASH_ATTN_EXT) { if (tensor->op == GGML_OP_FLASH_ATTN_EXT) {
const float *params = (const float *)tensor->op_params; const float * params = (const float *)tensor->op_params;
tensor_clone = ggml_flash_attn_ext(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3], params[0], params[1], params[2]); tensor_clone = ggml_flash_attn_ext(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3], params[0], params[1], params[2]);
} else if (tensor->op == GGML_OP_MUL_MAT) { } else if (tensor->op == GGML_OP_MUL_MAT) {
tensor_clone = ggml_mul_mat(ggml_ctx, src_clone[0], src_clone[1]); tensor_clone = ggml_mul_mat(ggml_ctx, src_clone[0], src_clone[1]);
@ -9283,7 +9646,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
} else if (tensor->op == GGML_OP_UPSCALE) { } else if (tensor->op == GGML_OP_UPSCALE) {
tensor_clone = ggml_upscale_ext(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]); tensor_clone = ggml_upscale_ext(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
} else if (tensor->op == GGML_OP_SCALE) { } else if (tensor->op == GGML_OP_SCALE) {
tensor_clone = ggml_scale(ggml_ctx, src_clone[0], ((float *)tensor->op_params)[0]); const float * params = (const float *)tensor->op_params;
tensor_clone = ggml_scale(ggml_ctx, src_clone[0], params[0]);
} else if (tensor->op == GGML_OP_SQR) { } else if (tensor->op == GGML_OP_SQR) {
tensor_clone = ggml_sqr(ggml_ctx, src_clone[0]); tensor_clone = ggml_sqr(ggml_ctx, src_clone[0]);
} else if (tensor->op == GGML_OP_SIN) { } else if (tensor->op == GGML_OP_SIN) {
@ -9291,7 +9655,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
} else if (tensor->op == GGML_OP_COS) { } else if (tensor->op == GGML_OP_COS) {
tensor_clone = ggml_cos(ggml_ctx, src_clone[0]); tensor_clone = ggml_cos(ggml_ctx, src_clone[0]);
} else if (tensor->op == GGML_OP_CLAMP) { } else if (tensor->op == GGML_OP_CLAMP) {
tensor_clone = ggml_clamp(ggml_ctx, src_clone[0], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]); const float * params = (const float *)tensor->op_params;
tensor_clone = ggml_clamp(ggml_ctx, src_clone[0], params[0], params[1]);
} else if (tensor->op == GGML_OP_PAD) { } else if (tensor->op == GGML_OP_PAD) {
tensor_clone = ggml_pad(ggml_ctx, src_clone[0], tensor->ne[0] - src_clone[0]->ne[0], tensor->ne[1] - src_clone[0]->ne[1], tensor->ne[2] - src_clone[0]->ne[2], tensor->ne[3] - src_clone[0]->ne[3]); tensor_clone = ggml_pad(ggml_ctx, src_clone[0], tensor->ne[0] - src_clone[0]->ne[0], tensor->ne[1] - src_clone[0]->ne[1], tensor->ne[2] - src_clone[0]->ne[2], tensor->ne[3] - src_clone[0]->ne[3]);
} else if (tensor->op == GGML_OP_REPEAT) { } else if (tensor->op == GGML_OP_REPEAT) {
@ -9305,7 +9670,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
} else if (tensor->op == GGML_OP_NORM) { } else if (tensor->op == GGML_OP_NORM) {
tensor_clone = ggml_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params); tensor_clone = ggml_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params);
} else if (tensor->op == GGML_OP_GROUP_NORM) { } else if (tensor->op == GGML_OP_GROUP_NORM) {
tensor_clone = ggml_group_norm(ggml_ctx, src_clone[0], *(int *)tensor->op_params, ((float *)tensor->op_params)[1]); const float * float_params = (const float *)tensor->op_params;
tensor_clone = ggml_group_norm(ggml_ctx, src_clone[0], tensor->op_params[0], float_params[1]);
} else if (tensor->op == GGML_OP_RMS_NORM) { } else if (tensor->op == GGML_OP_RMS_NORM) {
tensor_clone = ggml_rms_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params); tensor_clone = ggml_rms_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params);
} else if (tensor->op == GGML_OP_RMS_NORM_BACK) { } else if (tensor->op == GGML_OP_RMS_NORM_BACK) {
@ -9318,14 +9684,15 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
tensor_clone = ggml_l2_norm(ggml_ctx, src_clone[0], eps); tensor_clone = ggml_l2_norm(ggml_ctx, src_clone[0], eps);
} else if (tensor->op == GGML_OP_SOFT_MAX) { } else if (tensor->op == GGML_OP_SOFT_MAX) {
if (src1 != nullptr) { if (src1 != nullptr) {
tensor_clone = ggml_soft_max_ext(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]); const float * params = (const float *)tensor->op_params;
tensor_clone = ggml_soft_max_ext(ggml_ctx, src_clone[0], src_clone[1], params[0], params[1]);
} else { } else {
tensor_clone = ggml_soft_max(ggml_ctx, src_clone[0]); tensor_clone = ggml_soft_max(ggml_ctx, src_clone[0]);
} }
} else if (tensor->op == GGML_OP_SOFT_MAX_BACK) { } else if (tensor->op == GGML_OP_SOFT_MAX_BACK) {
tensor_clone = ggml_soft_max_ext_back(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]); tensor_clone = ggml_soft_max_ext_back(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
} else if (tensor->op == GGML_OP_DIAG_MASK_INF) { } else if (tensor->op == GGML_OP_DIAG_MASK_INF) {
tensor_clone = ggml_diag_mask_inf(ggml_ctx, src_clone[0], *(int *)tensor->op_params); tensor_clone = ggml_diag_mask_inf(ggml_ctx, src_clone[0], tensor->op_params[0]);
} else if (tensor->op == GGML_OP_ROPE || tensor->op == GGML_OP_ROPE_BACK) { } else if (tensor->op == GGML_OP_ROPE || tensor->op == GGML_OP_ROPE_BACK) {
const int n_dims = ((int32_t *) tensor->op_params)[1]; const int n_dims = ((int32_t *) tensor->op_params)[1];
const int mode = ((int32_t *) tensor->op_params)[2]; const int mode = ((int32_t *) tensor->op_params)[2];

View file

@ -212,7 +212,7 @@ void main() {
#else #else
ACC_TYPE sums[WMITER * TM * WNITER * TN]; ACC_TYPE sums[WMITER * TM * WNITER * TN];
FLOAT_TYPE cache_a[WMITER * TM]; FLOAT_TYPE cache_a[WMITER * TM];
FLOAT_TYPE cache_b[WNITER * TN]; FLOAT_TYPE cache_b[TN];
[[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) { [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
sums[i] = ACC_TYPE(0.0f); sums[i] = ACC_TYPE(0.0f);
@ -744,16 +744,14 @@ void main() {
} }
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
[[unroll]] for (uint j = 0; j < TN; j++) { [[unroll]] for (uint j = 0; j < TN; j++) {
cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * SHMEM_STRIDE + i]; cache_b[j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * SHMEM_STRIDE + i];
} }
}
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
[[unroll]] for (uint cc = 0; cc < TN; cc++) { [[unroll]] for (uint cc = 0; cc < TN; cc++) {
[[unroll]] for (uint cr = 0; cr < TM; cr++) { [[unroll]] for (uint cr = 0; cr < TM; cr++) {
const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr; const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
sums[sums_idx] = fma(ACC_TYPE(cache_a[wsir * TM + cr]), ACC_TYPE(cache_b[wsic * TN + cc]), sums[sums_idx]); sums[sums_idx] = fma(ACC_TYPE(cache_a[wsir * TM + cr]), ACC_TYPE(cache_b[cc]), sums[sums_idx]);
} }
} }
} }

View file

@ -0,0 +1,444 @@
#version 450
#extension GL_EXT_control_flow_attributes : enable
#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
#extension GL_EXT_integer_dot_product : require
#ifdef FLOAT16
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#endif
#ifdef COOPMAT
#extension GL_KHR_cooperative_matrix : enable
#extension GL_KHR_memory_scope_semantics : enable
#extension GL_KHR_shader_subgroup_basic : enable
#endif
#ifdef MUL_MAT_ID
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
#endif
#include "types.comp"
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];};
#if defined(A_TYPE_PACKED32)
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
#endif
layout (binding = 1) readonly buffer B {block_q8_1_packed32 data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
#ifdef MUL_MAT_ID
layout (binding = 3) readonly buffer IDS {int data_ids[];};
#endif
layout (push_constant) uniform parameter
{
uint M;
uint N;
uint K;
uint stride_a;
uint stride_b;
uint stride_d;
uint batch_stride_a;
uint batch_stride_b;
uint batch_stride_d;
#ifdef MUL_MAT_ID
uint nei0;
uint nei1;
uint nbi1;
uint ne11;
#else
uint k_split;
uint ne02;
uint ne12;
uint broadcast2;
uint broadcast3;
#endif
} p;
layout (constant_id = 0) const uint BLOCK_SIZE = 64;
layout (constant_id = 1) const uint BM = 64;
layout (constant_id = 2) const uint BN = 64;
// layout (constant_id = 3) const uint BK = 32;
layout (constant_id = 4) const uint WM = 32;
layout (constant_id = 5) const uint WN = 32;
layout (constant_id = 6) const uint WMITER = 2;
layout (constant_id = 7) const uint TM = 4;
layout (constant_id = 8) const uint TN = 2;
layout (constant_id = 9) const uint TK = 1; // Only needed for coopmat
layout (constant_id = 10) const uint WARP = 32;
#define BK 32
#ifdef COOPMAT
#define SHMEM_STRIDE (BK / 4 + 4)
#else
#define SHMEM_STRIDE (BK / 4 + 1)
#endif
shared int32_t buf_a_qs[BM * SHMEM_STRIDE];
#ifndef COOPMAT
#if QUANT_AUXF == 1
shared FLOAT_TYPE buf_a_dm[BM];
#else
shared FLOAT_TYPE_VEC2 buf_a_dm[BM];
#endif
#endif
shared int32_t buf_b_qs[BN * SHMEM_STRIDE];
#ifndef COOPMAT
shared FLOAT_TYPE_VEC2 buf_b_ds[BN];
#endif
#define LOAD_VEC_A (4 * QUANT_R)
#define LOAD_VEC_B 4
#ifdef MUL_MAT_ID
shared u16vec2 row_ids[3072];
#endif // MUL_MAT_ID
#define NUM_WARPS (BLOCK_SIZE / WARP)
#ifdef COOPMAT
shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS];
#endif
#include "mul_mmq_funcs.comp"
void main() {
#ifdef NEEDS_INIT_IQ_SHMEM
init_iq_shmem(gl_WorkGroupSize);
#endif
#ifdef MUL_MAT_ID
const uint expert_idx = gl_GlobalInvocationID.z;
#else
const uint batch_idx = gl_GlobalInvocationID.z;
const uint i13 = batch_idx / p.ne12;
const uint i12 = batch_idx % p.ne12;
const uint i03 = i13 / p.broadcast3;
const uint i02 = i12 / p.broadcast2;
const uint batch_idx_a = i03 * p.ne02 + i02;
#endif
const uint blocks_m = (p.M + BM - 1) / BM;
const uint ir = gl_WorkGroupID.x % blocks_m;
const uint ik = gl_WorkGroupID.x / blocks_m;
const uint ic = gl_WorkGroupID.y;
const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
const uint WSUBM = WM / WMITER;
const uint WSUBN = WN / WNITER;
#ifdef COOPMAT
const uint warp_i = gl_SubgroupID;
const uint tiw = gl_SubgroupInvocationID;
const uint cms_per_row = WM / TM;
const uint cms_per_col = WN / TN;
const uint storestride = WARP / TM;
const uint store_r = tiw % TM;
const uint store_c = tiw / TM;
#else
const uint warp_i = gl_LocalInvocationID.x / WARP;
const uint tiw = gl_LocalInvocationID.x % WARP;
const uint tiwr = tiw % (WSUBM / TM);
const uint tiwc = tiw / (WSUBM / TM);
#endif
const uint warp_r = warp_i % (BM / WM);
const uint warp_c = warp_i / (BM / WM);
const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A);
const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A);
const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B);
const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B);
const uint loadstride_a = BLOCK_SIZE * LOAD_VEC_A / BK;
const uint loadstride_b = BLOCK_SIZE * LOAD_VEC_B / BK;
#ifdef MUL_MAT_ID
uint _ne1 = 0;
for (uint ii1 = 0; ii1 < p.nei1; ii1++) {
for (uint ii0 = 0; ii0 < p.nei0; ii0++) {
if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) {
row_ids[_ne1] = u16vec2(ii0, ii1);
_ne1++;
}
}
}
barrier();
// Workgroup has no work
if (ic * BN >= _ne1) return;
#endif
#ifdef MUL_MAT_ID
const uint start_k = 0;
const uint end_k = p.K;
#else
const uint start_k = ik * p.k_split;
const uint end_k = min(p.K, (ik + 1) * p.k_split);
#endif
uint pos_a_ib = (
#ifdef MUL_MAT_ID
expert_idx * p.batch_stride_a +
#else
batch_idx_a * p.batch_stride_a +
#endif
ir * BM * p.stride_a + start_k) / BK;
#ifdef MUL_MAT_ID
uint pos_b_ib = 0;
#else
uint pos_b_ib = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / BK;
#endif
#ifdef COOPMAT
coopmat<int8_t, gl_ScopeSubgroup, TM, TK, gl_MatrixUseA> cache_a;
coopmat<int8_t, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> cache_b;
coopmat<int32_t, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> cm_result;
coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> factors[cms_per_row * cms_per_col];
coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> sums[cms_per_row * cms_per_col];
[[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) {
sums[i] = coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0.0f);
}
#else
int32_t cache_a_qs[WMITER * TM * BK / 4];
int32_t cache_b_qs[TN * BK / 4];
ACC_TYPE sums[WMITER * TM * WNITER * TN];
[[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
sums[i] = ACC_TYPE(0.0f);
}
#endif
#if QUANT_AUXF == 1
FLOAT_TYPE cache_a_dm[TM];
#else
FLOAT_TYPE_VEC2 cache_a_dm[TM];
#endif
FLOAT_TYPE_VEC2 cache_b_ds[TN];
for (uint block = start_k; block < end_k; block += BK) {
[[unroll]] for (uint l = 0; loadc_a + l < BM; l += loadstride_a) {
const uint ib = pos_a_ib + (loadc_a + l) * p.stride_a / BK;
const uint iqs = loadr_a;
const uint buf_ib = loadc_a + l;
// Should ds be gated to a single thread?
if (iqs == 0) {
#if QUANT_AUXF == 1
buf_a_dm[buf_ib] = get_d(ib);
#else
buf_a_dm[buf_ib] = get_dm(ib);
#endif
}
#if QUANT_R == 1
buf_a_qs[buf_ib * SHMEM_STRIDE + iqs] = repack(ib, iqs);
#else
const i32vec2 vals = repack(ib, iqs);
buf_a_qs[buf_ib * SHMEM_STRIDE + iqs ] = vals.x;
buf_a_qs[buf_ib * SHMEM_STRIDE + iqs + 4] = vals.y;
#endif
}
[[unroll]] for (uint l = 0; loadc_b + l < BN; l += loadstride_b) {
#ifdef MUL_MAT_ID
const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l];
const uint idx = pos_b_ib + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b;
const uint ib = idx / 8;
const uint iqs = idx & 0x7;
#else
const uint ib = pos_b_ib + (loadc_b + l) * p.stride_b / BK;
const uint iqs = loadr_b;
#endif
const uint buf_ib = loadc_b + l;
// Should ds be gated to a single thread?
if (iqs == 0) {
buf_b_ds[buf_ib] = FLOAT_TYPE_VEC2(data_b[ib].ds);
}
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs] = data_b[ib].qs[iqs];
}
barrier();
pos_a_ib += 1;
pos_b_ib += 1;
#ifdef COOPMAT
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
const uint ib_a = warp_r * WM + cm_row * TM;
// Load from shared into cache
coopMatLoad(cache_a, buf_a_qs, ib_a * SHMEM_STRIDE, SHMEM_STRIDE, gl_CooperativeMatrixLayoutRowMajor);
// TODO: only cache values that are actually needed
[[unroll]] for (uint t_idx = 0; t_idx < TM; t_idx++) {
cache_a_dm[t_idx] = buf_a_dm[ib_a + t_idx];
}
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
const uint ib_b = warp_c * WN + cm_col * TN;
coopMatLoad(cache_b, buf_b_qs, ib_b * SHMEM_STRIDE, SHMEM_STRIDE, gl_CooperativeMatrixLayoutColumnMajor);
// TODO: only cache values that are actually needed
[[unroll]] for (uint t_idx = 0; t_idx < TN; t_idx++) {
cache_b_dm[t_idx] = buf_b_d[ib_b + t_idx];
}
cm_result = coopmat<int32_t, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0);
cm_result = coopMatMulAdd(cache_a, cache_b, cm_result);
[[unroll]] for (uint col = 0; col < TN; col += storestride) {
coopmat_stage[warp_i * TM * TN + (store_c + col) * TM + store_r] = ACC_TYPE(float(cache_a_d[store_r]) * float(cache_b_d[store_c + col]));
}
coopMatLoad(factors, coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
sums[cm_col * cms_per_row + cm_row] += factors * coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(cm_result);
}
}
#else
// Load from shared into cache
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
const uint ib = warp_r * WM + wsir * WSUBM + tiwr * TM + cr;
cache_a_dm[wsir * TM + cr] = buf_a_dm[ib];
[[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) {
cache_a_qs[(wsir * TM + cr) * (BK / 4) + idx_k] = buf_a_qs[ib * SHMEM_STRIDE + idx_k];
}
}
}
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
[[unroll]] for (uint cc = 0; cc < TN; cc++) {
const uint ib = warp_c * WN + wsic * WSUBN + tiwc * TN + cc;
cache_b_ds[cc] = buf_b_ds[ib];
[[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) {
cache_b_qs[cc * (BK / 4) + idx_k] = buf_b_qs[ib * SHMEM_STRIDE + idx_k];
}
}
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
[[unroll]] for (uint cc = 0; cc < TN; cc++) {
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
const uint cache_a_idx = wsir * TM + cr;
const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
int32_t q_sum = 0;
[[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) {
q_sum += dotPacked4x8EXT(cache_a_qs[cache_a_idx * (BK / 4) + idx_k],
cache_b_qs[cc * (BK / 4) + idx_k]);
}
sums[sums_idx] += mul_q8_1(q_sum, cache_a_dm[cache_a_idx], cache_b_ds[cc]);
}
}
}
}
#endif
barrier();
}
const uint dr = ir * BM + warp_r * WM;
const uint dc = ic * BN + warp_c * WN;
#ifndef MUL_MAT_ID
const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
#endif
#ifdef COOPMAT
#ifdef MUL_MAT_ID
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
[[unroll]] for (uint col = 0; col < BN; col += storestride) {
const uint row_i = dc + cm_col * TN + col + store_c;
if (row_i >= _ne1) break;
const u16vec2 row_idx = row_ids[row_i];
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
}
}
}
#else
const bool is_aligned = p.stride_d % 4 == 0; // Assumption: D_TYPE == float
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
const bool is_in_bounds = dr + (cm_row + 1) * TM <= p.M && dc + (cm_col + 1) * TN <= p.N;
if (is_aligned && is_in_bounds) {
// Full coopMat is within bounds and stride_d is aligned with 16B
coopmat<D_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> cm_dtype = coopmat<D_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(sums[cm_col * cms_per_row + cm_row]);
coopMatStore(cm_dtype, data_d, offsets + (dc + cm_col * TN) * p.stride_d + dr + cm_row * TM, p.stride_d, gl_CooperativeMatrixLayoutColumnMajor);
} else if (is_in_bounds) {
// Full coopMat is within bounds, but stride_d is not aligned
coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
[[unroll]] for (uint col = 0; col < TN; col += storestride) {
data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
}
} else if (dr + cm_row * TM < p.M && dc + cm_col * TN < p.N) {
// Partial coopMat is within bounds
coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
[[unroll]] for (uint col = 0; col < TN; col += storestride) {
if (dr + cm_row * TM + store_r < p.M && dc + cm_col * TN + col + store_c < p.N) {
data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
}
}
}
}
}
#endif // MUL_MAT_ID
#else
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
const uint dr_warp = dr + wsir * WSUBM + tiwr * TM;
const uint dc_warp = dc + wsic * WSUBN + tiwc * TN;
[[unroll]] for (uint cc = 0; cc < TN; cc++) {
#ifdef MUL_MAT_ID
const uint row_i = dc_warp + cc;
if (row_i >= _ne1) break;
const u16vec2 row_idx = row_ids[row_i];
#endif // MUL_MAT_ID
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
#ifdef MUL_MAT_ID
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
#else
if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
}
#endif // MUL_MAT_ID
}
}
}
}
#endif // COOPMAT
}

View file

@ -0,0 +1,99 @@
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
#include "types.comp"
// Each iqs value maps to a 32-bit integer
#if defined(DATA_A_Q4_0)
i32vec2 repack(uint ib, uint iqs) {
// Use 2-byte loads since a q4_0 block (18 bytes) is not divisible by 4
const u16vec2 quants = u16vec2(data_a[ib].qs[iqs * 2 ],
data_a[ib].qs[iqs * 2 + 1]);
const uint32_t vui = pack32(quants);
return i32vec2( vui & 0x0F0F0F0F,
(vui >> 4) & 0x0F0F0F0F);
}
ACC_TYPE mul_q8_1(int32_t q_sum, float da, vec2 dsb) {
return ACC_TYPE(da * (float(q_sum) * dsb.x - 8.0 * dsb.y));
}
#endif
#if defined(DATA_A_Q4_1)
i32vec2 repack(uint ib, uint iqs) {
// Use 4-byte loads since a q4_1 block (20 bytes) is divisible by 4
const uint32_t vui = data_a_packed32[ib].qs[iqs];
return i32vec2( vui & 0x0F0F0F0F,
(vui >> 4) & 0x0F0F0F0F);
}
ACC_TYPE mul_q8_1(int32_t q_sum, vec2 dma, vec2 dsb) {
return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y);
}
#endif
#if defined(DATA_A_Q5_0)
i32vec2 repack(uint ib, uint iqs) {
// Use 2-byte loads since a q5_0 block (22 bytes) is not divisible by 4
const u16vec2 quants = u16vec2(data_a[ib].qs[iqs * 2 ],
data_a[ib].qs[iqs * 2 + 1]);
const uint32_t vui = pack32(quants);
const int32_t qh = int32_t((uint32_t(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0]) >> (4 * iqs));
const int32_t v0 = int32_t(vui & 0x0F0F0F0F)
| ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)
const int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F)
| (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28)
return i32vec2(v0, v1);
}
ACC_TYPE mul_q8_1(int32_t q_sum, float da, vec2 dsb) {
return ACC_TYPE(da * (float(q_sum) * dsb.x - 16.0 * dsb.y));
}
#endif
#if defined(DATA_A_Q5_1)
i32vec2 repack(uint ib, uint iqs) {
// Use 4-byte loads since a q5_1 block (24 bytes) is divisible by 4
const uint32_t vui = data_a_packed32[ib].qs[iqs];
const int32_t qh = int32_t(data_a_packed32[ib].qh >> (4 * iqs));
const int32_t v0 = int32_t(vui & 0x0F0F0F0F)
| ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)
const int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F)
| (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28)
return i32vec2(v0, v1);
}
ACC_TYPE mul_q8_1(int32_t q_sum, vec2 dma, vec2 dsb) {
return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y);
}
#endif
#if defined(DATA_A_Q8_0)
int32_t repack(uint ib, uint iqs) {
// Use 2-byte loads since a q8_0 block (34 bytes) is not divisible by 4
return pack32(i16vec2(data_a[ib].qs[iqs * 2 ],
data_a[ib].qs[iqs * 2 + 1]));
}
ACC_TYPE mul_q8_1(int32_t q_sum, float da, vec2 dsb) {
return ACC_TYPE(float(q_sum) * da * dsb.x);
}
#endif
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
FLOAT_TYPE get_d(uint ib) {
return FLOAT_TYPE(data_a[ib].d);
}
#endif
#if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1)
FLOAT_TYPE_VEC2 get_dm(uint ib) {
return FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
}
#endif

View file

@ -0,0 +1,77 @@
#version 450
#extension GL_EXT_control_flow_attributes : require
#extension GL_EXT_shader_16bit_storage : require
layout (push_constant) uniform parameter
{
uint ne;
} p;
#include "types.comp"
layout(constant_id = 0) const uint GROUP_SIZE = 32;
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {vec4 data_a[];};
layout (binding = 1) writeonly buffer D {block_q8_1_packed32 data_b[];};
shared float shmem[GROUP_SIZE];
void quantize() {
const uint wgid = gl_WorkGroupID.x;
const uint tid = gl_LocalInvocationID.x;
// Each thread handles a vec4, so 8 threads handle a block
const uint blocks_per_group = GROUP_SIZE / 8;
const uint block_in_wg = tid / 8;
const uint ib = wgid * blocks_per_group + block_in_wg;
const uint iqs = tid % 8;
if (ib >= gl_NumWorkGroups.x * blocks_per_group) {
return;
}
const uint a_idx = ib * 8 + iqs;
vec4 vals = a_idx < p.ne ? data_a[a_idx] : vec4(0.0f);
const vec4 abs_vals = abs(vals);
// Find absolute max for each block
shmem[tid] = max(max(abs_vals.x, abs_vals.y), max(abs_vals.z, abs_vals.w));
barrier();
[[unroll]] for (uint s = 4; s > 0; s >>= 1) {
if (iqs < s) {
shmem[tid] = max(shmem[tid], shmem[tid + s]);
}
barrier();
}
const float amax = shmem[block_in_wg * 8];
const float d = amax / 127.0;
const float d_inv = d != 0.0 ? 1.0 / d : 0.0;
vals = round(vals * d_inv);
data_b[ib].qs[iqs] = pack32(i8vec4(round(vals)));
barrier();
// Calculate the sum for each block
shmem[tid] = vals.x + vals.y + vals.z + vals.w;
barrier();
[[unroll]] for (uint s = 4; s > 0; s >>= 1) {
if (iqs < s) {
shmem[tid] += shmem[tid + s];
}
barrier();
}
if (iqs == 0) {
const float sum = shmem[tid];
data_b[ib].ds = f16vec2(vec2(d, sum * d));
}
}
void main() {
quantize();
}

View file

@ -0,0 +1,7 @@
#version 460
#extension GL_EXT_integer_dot_product : require
void main()
{
}

View file

@ -1,4 +1,3 @@
#if !defined(GGML_TYPES_COMP) #if !defined(GGML_TYPES_COMP)
#define GGML_TYPES_COMP #define GGML_TYPES_COMP
@ -51,6 +50,7 @@ struct block_q4_0_packed16
#if defined(DATA_A_Q4_0) #if defined(DATA_A_Q4_0)
#define QUANT_K QUANT_K_Q4_0 #define QUANT_K QUANT_K_Q4_0
#define QUANT_R QUANT_R_Q4_0 #define QUANT_R QUANT_R_Q4_0
#define QUANT_AUXF 1
#define A_TYPE block_q4_0 #define A_TYPE block_q4_0
#define A_TYPE_PACKED16 block_q4_0_packed16 #define A_TYPE_PACKED16 block_q4_0_packed16
#endif #endif
@ -72,11 +72,19 @@ struct block_q4_1_packed16
uint16_t qs[16/2]; uint16_t qs[16/2];
}; };
struct block_q4_1_packed32
{
f16vec2 dm;
uint32_t qs[16/4];
};
#if defined(DATA_A_Q4_1) #if defined(DATA_A_Q4_1)
#define QUANT_K QUANT_K_Q4_1 #define QUANT_K QUANT_K_Q4_1
#define QUANT_R QUANT_R_Q4_1 #define QUANT_R QUANT_R_Q4_1
#define QUANT_AUXF 2
#define A_TYPE block_q4_1 #define A_TYPE block_q4_1
#define A_TYPE_PACKED16 block_q4_1_packed16 #define A_TYPE_PACKED16 block_q4_1_packed16
#define A_TYPE_PACKED32 block_q4_1_packed32
#endif #endif
#define QUANT_K_Q5_0 32 #define QUANT_K_Q5_0 32
@ -99,6 +107,7 @@ struct block_q5_0_packed16
#if defined(DATA_A_Q5_0) #if defined(DATA_A_Q5_0)
#define QUANT_K QUANT_K_Q5_0 #define QUANT_K QUANT_K_Q5_0
#define QUANT_R QUANT_R_Q5_0 #define QUANT_R QUANT_R_Q5_0
#define QUANT_AUXF 1
#define A_TYPE block_q5_0 #define A_TYPE block_q5_0
#define A_TYPE_PACKED16 block_q5_0_packed16 #define A_TYPE_PACKED16 block_q5_0_packed16
#endif #endif
@ -122,11 +131,20 @@ struct block_q5_1_packed16
uint16_t qs[16/2]; uint16_t qs[16/2];
}; };
struct block_q5_1_packed32
{
f16vec2 dm;
uint qh;
uint32_t qs[16/4];
};
#if defined(DATA_A_Q5_1) #if defined(DATA_A_Q5_1)
#define QUANT_K QUANT_K_Q5_1 #define QUANT_K QUANT_K_Q5_1
#define QUANT_R QUANT_R_Q5_1 #define QUANT_R QUANT_R_Q5_1
#define QUANT_AUXF 2
#define A_TYPE block_q5_1 #define A_TYPE block_q5_1
#define A_TYPE_PACKED16 block_q5_1_packed16 #define A_TYPE_PACKED16 block_q5_1_packed16
#define A_TYPE_PACKED32 block_q5_1_packed32
#endif #endif
#define QUANT_K_Q8_0 32 #define QUANT_K_Q8_0 32
@ -142,14 +160,40 @@ struct block_q8_0_packed16
float16_t d; float16_t d;
int16_t qs[32/2]; int16_t qs[32/2];
}; };
struct block_q8_0_packed32
{
float16_t d;
int32_t qs[32/4];
};
#if defined(DATA_A_Q8_0) #if defined(DATA_A_Q8_0)
#define QUANT_K QUANT_K_Q8_0 #define QUANT_K QUANT_K_Q8_0
#define QUANT_R QUANT_R_Q8_0 #define QUANT_R QUANT_R_Q8_0
#define QUANT_AUXF 1
#define A_TYPE block_q8_0 #define A_TYPE block_q8_0
#define A_TYPE_PACKED16 block_q8_0_packed16 #define A_TYPE_PACKED16 block_q8_0_packed16
#define A_TYPE_PACKED32 block_q8_0_packed32
#endif #endif
#define QUANT_K_Q8_1 32
#define QUANT_R_Q8_1 1
struct block_q8_1
{
f16vec2 ds;
int8_t qs[32];
};
struct block_q8_1_packed16
{
f16vec2 ds;
int16_t qs[16];
};
struct block_q8_1_packed32
{
f16vec2 ds;
int32_t qs[8];
};
// K-quants // K-quants
#define QUANT_K_Q2_K 256 #define QUANT_K_Q2_K 256

View file

@ -35,6 +35,9 @@
#ifndef GGML_VULKAN_COOPMAT_GLSLC_SUPPORT #ifndef GGML_VULKAN_COOPMAT_GLSLC_SUPPORT
#define GGML_VULKAN_COOPMAT_GLSLC_SUPPORT #define GGML_VULKAN_COOPMAT_GLSLC_SUPPORT
#endif #endif
#ifndef GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT
#define GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT
#endif
std::mutex lock; std::mutex lock;
std::vector<std::pair<std::string, std::string>> shader_fnames; std::vector<std::pair<std::string, std::string>> shader_fnames;
@ -301,7 +304,10 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
std::string aligned_b_type_f32 = coopmat2 ? "float" : fp16 ? "mat2x4" : "vec4"; std::string aligned_b_type_f32 = coopmat2 ? "float" : fp16 ? "mat2x4" : "vec4";
std::string aligned_b_type_f16 = coopmat2 ? "float16_t" : fp16 ? "f16mat2x4" : "f16vec4"; std::string aligned_b_type_f16 = coopmat2 ? "float16_t" : fp16 ? "f16mat2x4" : "f16vec4";
std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", (coopmat2 || fp16) ? "float16_t" : "float"}}; std::map<std::string, std::string> base_dict = {
{"FLOAT_TYPE", (coopmat2 || fp16) ? "float16_t" : "float"},
{"FLOAT_TYPE_VEC2", (coopmat2 || fp16) ? "f16vec2" : "vec2"},
};
std::string shader_name = "matmul"; std::string shader_name = "matmul";
if (matmul_id) { if (matmul_id) {
@ -319,9 +325,7 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
base_dict["COOPMAT"] = "1"; base_dict["COOPMAT"] = "1";
} }
base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float"; const std::string source_name = coopmat2 ? "mul_mm_cm2.comp" : "mul_mm.comp";
std::string source_name = coopmat2 ? "mul_mm_cm2.comp" : "mul_mm.comp";
// Shaders with f16 B_TYPE // Shaders with f16 B_TYPE
string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc); string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc);
@ -345,14 +349,20 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
// don't generate f32 variants for coopmat2 // don't generate f32 variants for coopmat2
if (!coopmat2) { if (!coopmat2) {
string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc); string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
} }
if (tname != "f16" && tname != "f32") { if (tname != "f16" && tname != "f32") {
string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc); string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
} }
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
if (!coopmat && !coopmat2 && !matmul_id && (tname == "q4_0" || tname == "q4_1" || tname == "q5_0" || tname == "q5_1" || tname == "q8_0")) {
string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc);
}
#endif
} }
} }
@ -464,6 +474,7 @@ void process_shaders() {
string_to_spv("acc_f32", "acc.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); string_to_spv("acc_f32", "acc.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {}); string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {});
string_to_spv("quantize_q8_1", "quantize_q8_1.comp", {});
string_to_spv("mul_f32", "mul.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); string_to_spv("mul_f32", "mul.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});

View file

@ -287,6 +287,7 @@ class MODEL_ARCH(IntEnum):
CHAMELEON = auto() CHAMELEON = auto()
WAVTOKENIZER_DEC = auto() WAVTOKENIZER_DEC = auto()
PLM = auto() PLM = auto()
BAILINGMOE = auto()
class MODEL_TENSOR(IntEnum): class MODEL_TENSOR(IntEnum):
@ -490,6 +491,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.CHAMELEON: "chameleon", MODEL_ARCH.CHAMELEON: "chameleon",
MODEL_ARCH.WAVTOKENIZER_DEC: "wavtokenizer-dec", MODEL_ARCH.WAVTOKENIZER_DEC: "wavtokenizer-dec",
MODEL_ARCH.PLM: "plm", MODEL_ARCH.PLM: "plm",
MODEL_ARCH.BAILINGMOE: "bailingmoe",
} }
TENSOR_NAMES: dict[MODEL_TENSOR, str] = { TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
@ -1667,6 +1669,25 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.POSNET_ATTN_V, MODEL_TENSOR.POSNET_ATTN_V,
MODEL_TENSOR.POSNET_ATTN_OUT, MODEL_TENSOR.POSNET_ATTN_OUT,
], ],
MODEL_ARCH.BAILINGMOE: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_GATE_INP,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
MODEL_TENSOR.FFN_GATE_SHEXP,
MODEL_TENSOR.FFN_DOWN_SHEXP,
MODEL_TENSOR.FFN_UP_SHEXP,
],
# TODO # TODO
} }
@ -1719,6 +1740,9 @@ MODEL_TENSOR_SKIP: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_ROT_EMBD, MODEL_TENSOR.ATTN_ROT_EMBD,
], ],
MODEL_ARCH.BAILINGMOE: [
MODEL_TENSOR.ROPE_FREQS,
],
} }
# #

View file

@ -29,6 +29,7 @@ class TensorNameMap:
"shared", # t5 "shared", # t5
"rwkv.embeddings", # rwkv6 "rwkv.embeddings", # rwkv6
"model.embeddings", # rwkv7 "model.embeddings", # rwkv7
"model.word_embeddings", # bailingmoe
), ),
# Token type embeddings # Token type embeddings

Binary file not shown.

BIN
glslc.exe

Binary file not shown.

View file

@ -110,6 +110,8 @@ extern "C" {
LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 28, LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 28,
LLAMA_VOCAB_PRE_TYPE_GPT4O = 29, LLAMA_VOCAB_PRE_TYPE_GPT4O = 29,
LLAMA_VOCAB_PRE_TYPE_SUPERBPE = 30, LLAMA_VOCAB_PRE_TYPE_SUPERBPE = 30,
LLAMA_VOCAB_PRE_TYPE_TRILLION = 31,
LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32,
}; };
enum llama_rope_type { enum llama_rope_type {

View file

@ -66,6 +66,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_CHAMELEON, "chameleon" }, { LLM_ARCH_CHAMELEON, "chameleon" },
{ LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" }, { LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" },
{ LLM_ARCH_PLM, "plm" }, { LLM_ARCH_PLM, "plm" },
{ LLM_ARCH_BAILINGMOE, "bailingmoe" },
{ LLM_ARCH_UNKNOWN, "(unknown)" }, { LLM_ARCH_UNKNOWN, "(unknown)" },
}; };
@ -1409,6 +1410,29 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_POS_NET_ATTN_OUT, "posnet.%d.attn_output" }, { LLM_TENSOR_POS_NET_ATTN_OUT, "posnet.%d.attn_output" },
}, },
}, },
{
LLM_ARCH_BAILINGMOE,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
{ LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" },
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
},
},
{ {
LLM_ARCH_UNKNOWN, LLM_ARCH_UNKNOWN,
{ {

View file

@ -70,6 +70,7 @@ enum llm_arch {
LLM_ARCH_CHAMELEON, LLM_ARCH_CHAMELEON,
LLM_ARCH_WAVTOKENIZER_DEC, LLM_ARCH_WAVTOKENIZER_DEC,
LLM_ARCH_PLM, LLM_ARCH_PLM,
LLM_ARCH_BAILINGMOE,
LLM_ARCH_UNKNOWN, LLM_ARCH_UNKNOWN,
}; };

View file

@ -59,6 +59,8 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
{ "granite", LLM_CHAT_TEMPLATE_GRANITE }, { "granite", LLM_CHAT_TEMPLATE_GRANITE },
{ "gigachat", LLM_CHAT_TEMPLATE_GIGACHAT }, { "gigachat", LLM_CHAT_TEMPLATE_GIGACHAT },
{ "megrez", LLM_CHAT_TEMPLATE_MEGREZ }, { "megrez", LLM_CHAT_TEMPLATE_MEGREZ },
{ "yandex", LLM_CHAT_TEMPLATE_YANDEX },
{ "bailing", LLM_CHAT_TEMPLATE_BAILING },
}; };
llm_chat_template llm_chat_template_from_str(const std::string & name) { llm_chat_template llm_chat_template_from_str(const std::string & name) {
@ -168,6 +170,10 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
return LLM_CHAT_TEMPLATE_GIGACHAT; return LLM_CHAT_TEMPLATE_GIGACHAT;
} else if (tmpl_contains("<|role_start|>")) { } else if (tmpl_contains("<|role_start|>")) {
return LLM_CHAT_TEMPLATE_MEGREZ; return LLM_CHAT_TEMPLATE_MEGREZ;
} else if (tmpl_contains(" Ассистент:")) {
return LLM_CHAT_TEMPLATE_YANDEX;
} else if (tmpl_contains("<role>ASSISTANT</role>") && tmpl_contains("'HUMAN'")) {
return LLM_CHAT_TEMPLATE_BAILING;
} }
return LLM_CHAT_TEMPLATE_UNKNOWN; return LLM_CHAT_TEMPLATE_UNKNOWN;
} }
@ -567,6 +573,41 @@ int32_t llm_chat_apply_template(
if (add_ass) { if (add_ass) {
ss << "<|role_start|>assistant<|role_end|>"; ss << "<|role_start|>assistant<|role_end|>";
} }
} else if (tmpl == LLM_CHAT_TEMPLATE_YANDEX) {
// Yandex template ("\n\n" is defined as EOT token)
ss << "<s>";
for (size_t i = 0; i < chat.size(); i++) {
std::string role(chat[i]->role);
if (role == "user") {
ss << " Пользователь: " << chat[i]->content << "\n\n";
} else if (role == "assistant") {
ss << " Ассистент: " << chat[i]->content << "\n\n";
}
}
// Add generation prompt if needed
if (add_ass) {
ss << " Ассистент:[SEP]";
}
} else if (tmpl == LLM_CHAT_TEMPLATE_BAILING) {
// Bailing (Ling) template
for (auto message : chat) {
std::string role(message->role);
if (role == "user") {
role = "HUMAN";
} else {
std::transform(role.begin(), role.end(), role.begin(), ::toupper);
}
ss << "<role>" << role << "</role>" << message->content;
}
if (add_ass) {
ss << "<role>ASSISTANT</role>";
}
} else { } else {
// template not supported // template not supported
return -1; return -1;
@ -585,4 +626,3 @@ int32_t llama_chat_builtin_templates(const char ** output, size_t len) {
} }
return (int32_t) LLM_CHAT_TEMPLATES.size(); return (int32_t) LLM_CHAT_TEMPLATES.size();
} }

View file

@ -38,6 +38,8 @@ enum llm_chat_template {
LLM_CHAT_TEMPLATE_GRANITE, LLM_CHAT_TEMPLATE_GRANITE,
LLM_CHAT_TEMPLATE_GIGACHAT, LLM_CHAT_TEMPLATE_GIGACHAT,
LLM_CHAT_TEMPLATE_MEGREZ, LLM_CHAT_TEMPLATE_MEGREZ,
LLM_CHAT_TEMPLATE_YANDEX,
LLM_CHAT_TEMPLATE_BAILING,
LLM_CHAT_TEMPLATE_UNKNOWN, LLM_CHAT_TEMPLATE_UNKNOWN,
}; };

View file

@ -93,6 +93,7 @@ const char * llm_type_name(llm_type type) {
case LLM_TYPE_10B_128x3_66B: return "10B+128x3.66B"; case LLM_TYPE_10B_128x3_66B: return "10B+128x3.66B";
case LLM_TYPE_57B_A14B: return "57B.A14B"; case LLM_TYPE_57B_A14B: return "57B.A14B";
case LLM_TYPE_27B: return "27B"; case LLM_TYPE_27B: return "27B";
case LLM_TYPE_290B: return "290B";
default: return "?B"; default: return "?B";
} }
} }
@ -1333,6 +1334,21 @@ void llama_model::load_hparams(llama_model_loader & ml) {
ml.get_key(LLM_KV_ATTENTION_GROUPNORM_GROUPS, hparams.n_norm_groups); ml.get_key(LLM_KV_ATTENTION_GROUPNORM_GROUPS, hparams.n_norm_groups);
ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
} break; } break;
case LLM_ARCH_BAILINGMOE:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead);
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared);
ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale);
ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false);
switch (hparams.n_layer) {
case 28: type = LLM_TYPE_16B; break;
case 88: type = LLM_TYPE_290B; break;
default: type = LLM_TYPE_UNKNOWN;
}
} break;
default: throw std::runtime_error("unsupported model architecture"); default: throw std::runtime_error("unsupported model architecture");
} }
@ -3834,6 +3850,46 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {hparams.convnext.n_embd, n_embd}, 0); output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {hparams.convnext.n_embd, n_embd}, 0);
output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), {n_embd}, 0); output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), {n_embd}, 0);
} break; } break;
case LLM_ARCH_BAILINGMOE:
{
const int64_t n_ff_exp = hparams.n_ff_exp;
const int64_t n_expert_shared = hparams.n_expert_shared;
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
// output
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
for (int i = 0; i < n_layer; ++i) {
auto & layer = layers[i];
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_head * n_rot}, 0);
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_head_kv * n_rot}, 0);
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_head_kv * n_rot}, 0);
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_rot, n_embd}, 0);
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
if (n_expert == 0) {
throw std::runtime_error("n_expert must be > 0");
}
if (n_expert_used == 0) {
throw std::runtime_error("n_expert_used must be > 0");
}
layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0);
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0);
layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0);
layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0);
layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, 0);
layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0);
}
} break;
default: default:
throw std::runtime_error("unknown architecture"); throw std::runtime_error("unknown architecture");
} }
@ -4122,6 +4178,14 @@ void llama_model::print_info() const {
LLAMA_LOG_INFO("%s: f_attention_scale = %f\n", __func__, hparams.f_attention_scale); LLAMA_LOG_INFO("%s: f_attention_scale = %f\n", __func__, hparams.f_attention_scale);
} }
if (arch == LLM_ARCH_BAILINGMOE) {
LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead);
LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp);
LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared);
LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale);
LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm);
}
vocab.print_info(); vocab.print_info();
} }
@ -11917,6 +11981,150 @@ struct llm_build_plm : public llm_graph_context {
} }
}; };
struct llm_build_bailingmoe : public llm_graph_context {
llm_build_bailingmoe(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
ggml_tensor * cur;
ggml_tensor * inpL;
inpL = build_inp_embd(model.tok_embd);
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn = build_attn_inp_kv_unified();
for (int il = 0; il < n_layer; ++il) {
ggml_tensor * inpSA = inpL;
// norm
cur = build_norm(inpL,
model.layers[il].attn_norm, NULL,
LLM_NORM_RMS, il);
cb(cur, "attn_norm", il);
// self-attention
{
// rope freq factors for llama3; may return nullptr for llama2 and other models
ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
// compute Q and K and RoPE them
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
cb(Qcur, "Qcur", il);
if (model.layers[il].bq) {
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
cb(Qcur, "Qcur", il);
}
ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
cb(Kcur, "Kcur", il);
if (model.layers[il].bk) {
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
cb(Kcur, "Kcur", il);
}
ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
cb(Vcur, "Vcur", il);
if (model.layers[il].bv) {
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
cb(Vcur, "Vcur", il);
}
Qcur = ggml_reshape_3d(ctx0, Qcur, n_rot, n_head, n_tokens);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_rot, n_head_kv, n_tokens);
Vcur = ggml_reshape_3d(ctx0, Vcur, n_rot, n_head_kv, n_tokens);
Qcur = ggml_rope_ext(
ctx0, Qcur, inp_pos, rope_factors,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
Kcur = ggml_rope_ext(
ctx0, Kcur, inp_pos, rope_factors,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
cb(Vcur, "Vcur", il);
cur = build_attn(inp_attn, gf,
model.layers[il].wo, model.layers[il].bo,
Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_rot)), il);
}
if (il == n_layer - 1) {
// skip computing output for unused tokens
ggml_tensor * inp_out_ids = build_inp_out_ids();
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
cur = build_norm(ffn_inp,
model.layers[il].ffn_norm, NULL,
LLM_NORM_RMS, il);
cb(cur, "ffn_norm", il);
ggml_tensor * moe_out =
build_moe_ffn(cur,
model.layers[il].ffn_gate_inp,
model.layers[il].ffn_up_exps,
model.layers[il].ffn_gate_exps,
model.layers[il].ffn_down_exps,
nullptr,
n_expert, n_expert_used,
LLM_FFN_SILU, hparams.expert_weights_norm,
false, hparams.expert_weights_scale,
LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
il);
cb(moe_out, "ffn_moe_out", il);
// FFN shared expert
{
ggml_tensor * ffn_shexp = build_ffn(cur,
model.layers[il].ffn_up_shexp, NULL, NULL,
model.layers[il].ffn_gate_shexp, NULL, NULL,
model.layers[il].ffn_down_shexp, NULL, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, il);
cb(ffn_shexp, "ffn_shexp", il);
cur = ggml_add(ctx0, moe_out, ffn_shexp);
cb(cur, "ffn_out", il);
}
cur = ggml_add(ctx0, cur, ffn_inp);
cur = build_cvec(cur, il);
cb(cur, "l_out", il);
// input for next layer
inpL = cur;
}
cur = inpL;
cur = build_norm(cur,
model.output_norm, NULL,
LLM_NORM_RMS, -1);
cb(cur, "result_norm", -1);
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cb(cur, "result_output", -1);
res->t_logits = cur;
ggml_build_forward_expand(gf, cur);
}
};
llama_memory_i * llama_model::create_memory() const { llama_memory_i * llama_model::create_memory() const {
llama_memory_i * res; llama_memory_i * res;
@ -12193,6 +12401,10 @@ llm_graph_result_ptr llama_model::build_graph(
{ {
llm = std::make_unique<llm_build_plm>(*this, params, gf); llm = std::make_unique<llm_build_plm>(*this, params, gf);
} break; } break;
case LLM_ARCH_BAILINGMOE:
{
llm = std::make_unique<llm_build_bailingmoe>(*this, params, gf);
} break;
default: default:
GGML_ABORT("fatal error"); GGML_ABORT("fatal error");
} }
@ -12324,6 +12536,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
case LLM_ARCH_GRANITE: case LLM_ARCH_GRANITE:
case LLM_ARCH_GRANITE_MOE: case LLM_ARCH_GRANITE_MOE:
case LLM_ARCH_CHAMELEON: case LLM_ARCH_CHAMELEON:
case LLM_ARCH_BAILINGMOE:
return LLAMA_ROPE_TYPE_NORM; return LLAMA_ROPE_TYPE_NORM;
// the pairs of head values are offset by n_rot/2 // the pairs of head values are offset by n_rot/2

View file

@ -85,6 +85,7 @@ enum llm_type {
LLM_TYPE_10B_128x3_66B, LLM_TYPE_10B_128x3_66B,
LLM_TYPE_57B_A14B, LLM_TYPE_57B_A14B,
LLM_TYPE_27B, LLM_TYPE_27B,
LLM_TYPE_290B,
}; };
struct llama_layer_posnet { struct llama_layer_posnet {

View file

@ -567,6 +567,7 @@ struct llm_tokenizer_bpe : llm_tokenizer {
case LLAMA_VOCAB_PRE_TYPE_MPT: case LLAMA_VOCAB_PRE_TYPE_MPT:
case LLAMA_VOCAB_PRE_TYPE_OLMO: case LLAMA_VOCAB_PRE_TYPE_OLMO:
case LLAMA_VOCAB_PRE_TYPE_JAIS: case LLAMA_VOCAB_PRE_TYPE_JAIS:
case LLAMA_VOCAB_PRE_TYPE_TRILLION:
regex_exprs = { regex_exprs = {
"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
}; };
@ -631,6 +632,13 @@ struct llm_tokenizer_bpe : llm_tokenizer {
"(?=(\\d{3})+(?!\\d))", "(?=(\\d{3})+(?!\\d))",
}; };
break; break;
case LLAMA_VOCAB_PRE_TYPE_BAILINGMOE:
regex_exprs = {
// original regex from tokenizer.json
// "'(?i:[sdmt]|ll|ve|re)|[^\\r\\n\\p{L}\\p{N}]?+\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]++[\\r\\n]*|\\s*[\\r\\n]|\\s+(?!\\S)|\\s+"
"'(?:[sSdDmMtT]|[lL][lL]|[vV][eE]|[rR][eE])|[^\\r\\n\\p{L}\\p{N}]?+\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]++[\\r\\n]*|\\s*[\\r\\n]|\\s+(?!\\S)|\\s+",
};
break;
default: default:
// default regex for BPE tokenization pre-processing // default regex for BPE tokenization pre-processing
regex_exprs = { regex_exprs = {
@ -1849,6 +1857,14 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
tokenizer_pre == "superbpe") { tokenizer_pre == "superbpe") {
pre_type = LLAMA_VOCAB_PRE_TYPE_SUPERBPE; pre_type = LLAMA_VOCAB_PRE_TYPE_SUPERBPE;
clean_spaces = false; clean_spaces = false;
} else if (
tokenizer_pre == "trillion") {
pre_type = LLAMA_VOCAB_PRE_TYPE_TRILLION;
clean_spaces = false;
} else if (
tokenizer_pre == "bailingmoe") {
pre_type = LLAMA_VOCAB_PRE_TYPE_BAILINGMOE;
clean_spaces = false;
} else { } else {
throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str())); throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
} }
@ -2029,6 +2045,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|| t.first == "<end_of_turn>" || t.first == "<end_of_turn>"
|| t.first == "<|endoftext|>" || t.first == "<|endoftext|>"
|| t.first == "<EOT>" || t.first == "<EOT>"
|| t.first == "_<EOT>"
|| t.first == "<end▁of▁sentence>" // DeepSeek || t.first == "<end▁of▁sentence>" // DeepSeek
) { ) {
special_eot_id = t.second; special_eot_id = t.second;
@ -2061,6 +2078,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|| t.first == "<fim-prefix>" || t.first == "<fim-prefix>"
|| t.first == "<fim▁begin>" // DeepSeek || t.first == "<fim▁begin>" // DeepSeek
|| t.first == "<PRE>" || t.first == "<PRE>"
|| t.first == "▁<PRE>" // CodeLlama
) { ) {
special_fim_pre_id = t.second; special_fim_pre_id = t.second;
if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
@ -2078,6 +2096,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|| t.first == "<fim-suffix>" || t.first == "<fim-suffix>"
|| t.first == "<fim▁hole>" // DeepSeek || t.first == "<fim▁hole>" // DeepSeek
|| t.first == "<SUF>" || t.first == "<SUF>"
|| t.first == "▁<SUF>" // CodeLlama
) { ) {
special_fim_suf_id = t.second; special_fim_suf_id = t.second;
if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
@ -2095,6 +2114,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|| t.first == "<fim-middle>" || t.first == "<fim-middle>"
|| t.first == "<fim▁end>" // DeepSeek || t.first == "<fim▁end>" // DeepSeek
|| t.first == "<MID>" || t.first == "<MID>"
|| t.first == "▁<MID>" // CodeLlama
) { ) {
special_fim_mid_id = t.second; special_fim_mid_id = t.second;
if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
@ -2179,6 +2199,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|| t.first == "<|endoftext|>" || t.first == "<|endoftext|>"
|| t.first == "<|eom_id|>" || t.first == "<|eom_id|>"
|| t.first == "<EOT>" || t.first == "<EOT>"
|| t.first == "_<EOT>"
) { ) {
special_eog_ids.insert(t.second); special_eog_ids.insert(t.second);
if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {