mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2025-09-10 17:14:36 +00:00
Merge branch 'upstream' into concedo_experimental
# Conflicts: # .github/workflows/build-linux-cross.yml # .github/workflows/build.yml # cmake/build-info.cmake # common/CMakeLists.txt # examples/llava/README.md # examples/server/README.md # ggml/CMakeLists.txt # ggml/src/ggml-cuda/CMakeLists.txt # ggml/src/ggml-rpc/ggml-rpc.cpp # ggml/src/ggml-vulkan/CMakeLists.txt # ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt # scripts/sync-ggml.last # tests/test-backend-ops.cpp # tests/test-chat-template.cpp
This commit is contained in:
commit
d8f1f73dd7
25 changed files with 522 additions and 126 deletions
|
@ -2784,7 +2784,10 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||||
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_THREADS_HTTP"));
|
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_THREADS_HTTP"));
|
||||||
add_opt(common_arg(
|
add_opt(common_arg(
|
||||||
{"--cache-reuse"}, "N",
|
{"--cache-reuse"}, "N",
|
||||||
string_format("min chunk size to attempt reusing from the cache via KV shifting (default: %d)", params.n_cache_reuse),
|
string_format(
|
||||||
|
"min chunk size to attempt reusing from the cache via KV shifting (default: %d)\n"
|
||||||
|
"[(card)](https://ggml.ai/f0.png)", params.n_cache_reuse
|
||||||
|
),
|
||||||
[](common_params & params, int value) {
|
[](common_params & params, int value) {
|
||||||
params.n_cache_reuse = value;
|
params.n_cache_reuse = value;
|
||||||
}
|
}
|
||||||
|
|
|
@ -419,7 +419,9 @@ class ModelBase:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def load_hparams(dir_model: Path):
|
def load_hparams(dir_model: Path):
|
||||||
try:
|
try:
|
||||||
return AutoConfig.from_pretrained(dir_model).to_dict()
|
# for security reason, we don't allow loading remote code by default
|
||||||
|
# if a model need remote code, we will fallback to config.json
|
||||||
|
return AutoConfig.from_pretrained(dir_model, trust_remote_code=False).to_dict()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to load model config from {dir_model}: {e}")
|
logger.warning(f"Failed to load model config from {dir_model}: {e}")
|
||||||
logger.warning("Trying to load config.json instead")
|
logger.warning("Trying to load config.json instead")
|
||||||
|
@ -1899,7 +1901,10 @@ class LlamaModel(TextModel):
|
||||||
raise ValueError(f"Unprocessed experts: {experts}")
|
raise ValueError(f"Unprocessed experts: {experts}")
|
||||||
|
|
||||||
|
|
||||||
@ModelBase.register("LlavaForConditionalGeneration")
|
@ModelBase.register(
|
||||||
|
"LlavaForConditionalGeneration", # pixtral
|
||||||
|
"Mistral3ForConditionalGeneration", # mistral small 3.1
|
||||||
|
)
|
||||||
class LlavaVisionModel(VisionModel):
|
class LlavaVisionModel(VisionModel):
|
||||||
img_break_tok_id = -1
|
img_break_tok_id = -1
|
||||||
|
|
||||||
|
@ -1908,17 +1913,38 @@ class LlavaVisionModel(VisionModel):
|
||||||
if self.hparams["model_type"] == "pixtral":
|
if self.hparams["model_type"] == "pixtral":
|
||||||
# layer_norm_eps is not in config.json, it is hard-coded in modeling_pixtral.py
|
# layer_norm_eps is not in config.json, it is hard-coded in modeling_pixtral.py
|
||||||
self.hparams["layer_norm_eps"] = self.hparams.get("layer_norm_eps", 1e-5)
|
self.hparams["layer_norm_eps"] = self.hparams.get("layer_norm_eps", 1e-5)
|
||||||
self.img_break_tok_id = 12 # see tokenizer_config.json
|
self.img_break_tok_id = self.get_token_id("[IMG_BREAK]")
|
||||||
|
logger.info(f"Image break token id: {self.img_break_tok_id}")
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported model type: {self.hparams['model_type']}")
|
raise ValueError(f"Unsupported model type: {self.hparams['model_type']}")
|
||||||
|
|
||||||
|
def get_token_id(self, token: str) -> int:
|
||||||
|
tokenizer_config_file = self.dir_model / 'tokenizer_config.json'
|
||||||
|
with open(tokenizer_config_file, "r", encoding="utf-8") as f:
|
||||||
|
added_tokens_decoder = json.load(f)['added_tokens_decoder']
|
||||||
|
for id_, token_data in added_tokens_decoder.items():
|
||||||
|
if token_data["content"] == token:
|
||||||
|
return int(id_)
|
||||||
|
raise ValueError(f"Token '{token}' not found in tokenizer config.")
|
||||||
|
|
||||||
def set_gguf_parameters(self):
|
def set_gguf_parameters(self):
|
||||||
super().set_gguf_parameters()
|
super().set_gguf_parameters()
|
||||||
hparams = self.hparams
|
hparams = self.hparams
|
||||||
if hparams["model_type"] == "pixtral":
|
if hparams["model_type"] == "pixtral":
|
||||||
self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.PIXTRAL)
|
self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.PIXTRAL)
|
||||||
self.gguf_writer.add_vision_attention_layernorm_eps(hparams["layer_norm_eps"])
|
self.gguf_writer.add_vision_attention_layernorm_eps(hparams["layer_norm_eps"])
|
||||||
self.gguf_writer.add_vision_use_silu(True)
|
|
||||||
|
# hidden_act
|
||||||
|
if hparams["hidden_act"] == "silu":
|
||||||
|
self.gguf_writer.add_vision_use_silu(True)
|
||||||
|
elif hparams["hidden_act"] == "gelu":
|
||||||
|
self.gguf_writer.add_vision_use_gelu(True)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported hidden_act: {hparams['hidden_act']}")
|
||||||
|
|
||||||
|
# spatial_merge_size
|
||||||
|
if "spatial_merge_size" in self.global_config:
|
||||||
|
self.gguf_writer.add_vision_spatial_merge_size(self.global_config["spatial_merge_size"])
|
||||||
|
|
||||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||||
del bid # unused
|
del bid # unused
|
||||||
|
|
|
@ -31,6 +31,7 @@
|
||||||
#define KEY_FEATURE_LAYER "clip.vision.feature_layer"
|
#define KEY_FEATURE_LAYER "clip.vision.feature_layer"
|
||||||
#define KEY_PROJ_SCALE_FACTOR "clip.vision.projector.scale_factor"
|
#define KEY_PROJ_SCALE_FACTOR "clip.vision.projector.scale_factor"
|
||||||
#define KEY_PROJ_TYPE "clip.projector_type"
|
#define KEY_PROJ_TYPE "clip.projector_type"
|
||||||
|
#define KEY_SPATIAL_MERGE_SIZE "clip.vision.spatial_merge_size"
|
||||||
|
|
||||||
#define KEY_USE_GLU_MLP "clip.use_glu_mlp" // for qwen2.5vl
|
#define KEY_USE_GLU_MLP "clip.use_glu_mlp" // for qwen2.5vl
|
||||||
#define KEY_USE_RMS_NORM "clip.use_rms_norm" // for qwen2.5vl
|
#define KEY_USE_RMS_NORM "clip.use_rms_norm" // for qwen2.5vl
|
||||||
|
@ -68,9 +69,11 @@
|
||||||
#define TN_MVLM_PROJ_BLOCK "mm.model.mb_block.%d.block.%d.%s"
|
#define TN_MVLM_PROJ_BLOCK "mm.model.mb_block.%d.block.%d.%s"
|
||||||
#define TN_MVLM_PROJ_PEG "mm.model.peg.%d.%s"
|
#define TN_MVLM_PROJ_PEG "mm.model.peg.%d.%s"
|
||||||
#define TN_IMAGE_NEWLINE "model.image_newline"
|
#define TN_IMAGE_NEWLINE "model.image_newline"
|
||||||
|
#define TN_MM_INP_NORM "mm.input_norm.weight"
|
||||||
#define TN_MM_INP_PROJ "mm.input_projection.weight" // gemma3
|
#define TN_MM_INP_PROJ "mm.input_projection.weight" // gemma3
|
||||||
#define TN_MM_SOFT_EMB_N "mm.soft_emb_norm.weight" // gemma3
|
#define TN_MM_SOFT_EMB_N "mm.soft_emb_norm.weight" // gemma3
|
||||||
#define TN_MM_PROJECTOR "mm.model.fc.weight" // idefics3
|
#define TN_MM_PROJECTOR "mm.model.fc.weight" // idefics3
|
||||||
|
#define TN_MM_PATCH_MERGER "mm.patch_merger.weight" // mistral small 3.1
|
||||||
#define TN_TOK_IMG_BREAK "v.token_embd.img_break" // pixtral
|
#define TN_TOK_IMG_BREAK "v.token_embd.img_break" // pixtral
|
||||||
|
|
||||||
// mimicpmv
|
// mimicpmv
|
||||||
|
|
|
@ -186,6 +186,7 @@ struct clip_hparams {
|
||||||
std::unordered_set<int32_t> vision_feature_layer;
|
std::unordered_set<int32_t> vision_feature_layer;
|
||||||
int32_t attn_window_size = 0;
|
int32_t attn_window_size = 0;
|
||||||
int32_t n_wa_pattern = 0;
|
int32_t n_wa_pattern = 0;
|
||||||
|
int32_t spatial_merge_size = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct clip_layer {
|
struct clip_layer {
|
||||||
|
@ -246,6 +247,7 @@ struct clip_vision_model {
|
||||||
struct ggml_tensor * projection;
|
struct ggml_tensor * projection;
|
||||||
|
|
||||||
// LLaVA projection
|
// LLaVA projection
|
||||||
|
struct ggml_tensor * mm_input_norm_w = nullptr;
|
||||||
struct ggml_tensor * mm_0_w = nullptr;
|
struct ggml_tensor * mm_0_w = nullptr;
|
||||||
struct ggml_tensor * mm_0_b = nullptr;
|
struct ggml_tensor * mm_0_b = nullptr;
|
||||||
struct ggml_tensor * mm_2_w = nullptr;
|
struct ggml_tensor * mm_2_w = nullptr;
|
||||||
|
@ -325,6 +327,7 @@ struct clip_vision_model {
|
||||||
|
|
||||||
// pixtral
|
// pixtral
|
||||||
struct ggml_tensor * token_embd_img_break = nullptr;
|
struct ggml_tensor * token_embd_img_break = nullptr;
|
||||||
|
struct ggml_tensor * mm_patch_merger_w = nullptr;
|
||||||
};
|
};
|
||||||
|
|
||||||
bool enable_gpu_clip = true;
|
bool enable_gpu_clip = true;
|
||||||
|
@ -662,6 +665,7 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i
|
||||||
const int d_head = hidden_size / n_head;
|
const int d_head = hidden_size / n_head;
|
||||||
const int n_layer = hparams.n_layer;
|
const int n_layer = hparams.n_layer;
|
||||||
const float eps = hparams.eps;
|
const float eps = hparams.eps;
|
||||||
|
const int n_merge = hparams.spatial_merge_size;
|
||||||
|
|
||||||
struct ggml_init_params params = {
|
struct ggml_init_params params = {
|
||||||
/*.mem_size =*/ ctx->buf_compute_meta.size(),
|
/*.mem_size =*/ ctx->buf_compute_meta.size(),
|
||||||
|
@ -746,7 +750,13 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i
|
||||||
{
|
{
|
||||||
ggml_tensor * gate_proj = ggml_mul_mat(ctx0, model.layers[il].ff_gate_w, cur);
|
ggml_tensor * gate_proj = ggml_mul_mat(ctx0, model.layers[il].ff_gate_w, cur);
|
||||||
ggml_tensor * up_proj = ggml_mul_mat(ctx0, model.layers[il].ff_up_w, cur);
|
ggml_tensor * up_proj = ggml_mul_mat(ctx0, model.layers[il].ff_up_w, cur);
|
||||||
gate_proj = ggml_silu(ctx0, gate_proj); // pixtral uses silu
|
if (ctx->use_silu) {
|
||||||
|
gate_proj = ggml_silu(ctx0, gate_proj);
|
||||||
|
} else if (ctx->use_gelu) {
|
||||||
|
gate_proj = ggml_gelu(ctx0, gate_proj);
|
||||||
|
} else {
|
||||||
|
GGML_ABORT("Pixtral: Unsupported activation");
|
||||||
|
}
|
||||||
cur = ggml_mul(ctx0, up_proj, gate_proj);
|
cur = ggml_mul(ctx0, up_proj, gate_proj);
|
||||||
cur = ggml_mul_mat(ctx0, model.layers[il].ff_down_w, cur);
|
cur = ggml_mul_mat(ctx0, model.layers[il].ff_down_w, cur);
|
||||||
}
|
}
|
||||||
|
@ -757,14 +767,42 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i
|
||||||
embeddings = cur;
|
embeddings = cur;
|
||||||
}
|
}
|
||||||
|
|
||||||
// LlavaMultiModalProjector (with GELU activation)
|
// mistral small 3.1 patch merger
|
||||||
|
// ref: https://github.com/huggingface/transformers/blob/7a3e208892c06a5e278144eaf38c8599a42f53e7/src/transformers/models/mistral3/modeling_mistral3.py#L67
|
||||||
|
if (model.mm_patch_merger_w) {
|
||||||
|
GGML_ASSERT(hparams.spatial_merge_size > 0);
|
||||||
|
|
||||||
|
ggml_tensor * cur = embeddings;
|
||||||
|
cur = ggml_mul(ctx0, ggml_rms_norm(ctx0, cur, eps), model.mm_input_norm_w);
|
||||||
|
|
||||||
|
// reshape image tokens to 2D grid
|
||||||
|
cur = ggml_reshape_3d(ctx0, cur, hidden_size, n_patches_x, n_patches_y);
|
||||||
|
cur = ggml_permute(ctx0, cur, 2, 0, 1, 3); // [x, y, hidden_size]
|
||||||
|
cur = ggml_cont(ctx0, cur);
|
||||||
|
|
||||||
|
// torch.nn.functional.unfold is just an im2col under the hood
|
||||||
|
// we just need a dummy kernel to make it work
|
||||||
|
ggml_tensor * kernel = ggml_view_3d(ctx0, cur, n_merge, n_merge, cur->ne[2], 0, 0, 0);
|
||||||
|
cur = ggml_im2col(ctx0, kernel, cur, n_merge, n_merge, 0, 0, 1, 1, true, inp->type);
|
||||||
|
|
||||||
|
// project to hidden_size
|
||||||
|
cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], cur->ne[1] * cur->ne[2]);
|
||||||
|
cur = ggml_mul_mat(ctx0, model.mm_patch_merger_w, cur);
|
||||||
|
embeddings = cur;
|
||||||
|
}
|
||||||
|
|
||||||
|
// LlavaMultiModalProjector (always using GELU activation)
|
||||||
{
|
{
|
||||||
embeddings = ggml_mul_mat(ctx0, model.mm_1_w, embeddings);
|
embeddings = ggml_mul_mat(ctx0, model.mm_1_w, embeddings);
|
||||||
embeddings = ggml_add(ctx0, embeddings, model.mm_1_b);
|
if (model.mm_1_b) {
|
||||||
|
embeddings = ggml_add(ctx0, embeddings, model.mm_1_b);
|
||||||
|
}
|
||||||
|
|
||||||
embeddings = ggml_gelu(ctx0, embeddings);
|
embeddings = ggml_gelu(ctx0, embeddings);
|
||||||
embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings);
|
embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings);
|
||||||
embeddings = ggml_add(ctx0, embeddings, model.mm_2_b);
|
if (model.mm_2_b) {
|
||||||
|
embeddings = ggml_add(ctx0, embeddings, model.mm_2_b);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// arrangement of the [IMG_BREAK] token
|
// arrangement of the [IMG_BREAK] token
|
||||||
|
@ -774,11 +812,14 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i
|
||||||
// and then concatenate the [IMG_BREAK] token to the end of each row, aka n_patches_per_row dimension
|
// and then concatenate the [IMG_BREAK] token to the end of each row, aka n_patches_per_row dimension
|
||||||
// after the concatenation, we have a tensor with shape [hidden_size, n_patches_per_row + 1, n_rows]
|
// after the concatenation, we have a tensor with shape [hidden_size, n_patches_per_row + 1, n_rows]
|
||||||
|
|
||||||
|
const int p_y = n_merge > 0 ? n_patches_y / n_merge : n_patches_y;
|
||||||
|
const int p_x = n_merge > 0 ? n_patches_x / n_merge : n_patches_x;
|
||||||
|
const int p_total = p_x * p_y;
|
||||||
const int n_embd_text = embeddings->ne[0];
|
const int n_embd_text = embeddings->ne[0];
|
||||||
const int n_tokens_output = num_patches + n_patches_y - 1; // one [IMG_BREAK] per row, except the last row
|
const int n_tokens_output = p_total + p_y - 1; // one [IMG_BREAK] per row, except the last row
|
||||||
|
|
||||||
ggml_tensor * cur = ggml_reshape_3d(ctx0, embeddings, n_embd_text, n_patches_x, n_patches_y);
|
ggml_tensor * cur = ggml_reshape_3d(ctx0, embeddings, n_embd_text, p_x, p_y);
|
||||||
ggml_tensor * tok = ggml_new_tensor_3d(ctx0, embeddings->type, n_embd_text, 1, n_patches_y);
|
ggml_tensor * tok = ggml_new_tensor_3d(ctx0, embeddings->type, n_embd_text, 1, p_y);
|
||||||
tok = ggml_scale(ctx0, tok, 0.0); // clear the tensor
|
tok = ggml_scale(ctx0, tok, 0.0); // clear the tensor
|
||||||
tok = ggml_add(ctx0, tok, model.token_embd_img_break);
|
tok = ggml_add(ctx0, tok, model.token_embd_img_break);
|
||||||
cur = ggml_concat(ctx0, cur, tok, 1);
|
cur = ggml_concat(ctx0, cur, tok, 1);
|
||||||
|
@ -1780,6 +1821,7 @@ struct clip_model_loader {
|
||||||
case PROJECTOR_TYPE_PIXTRAL:
|
case PROJECTOR_TYPE_PIXTRAL:
|
||||||
{
|
{
|
||||||
hparams.rope_theta = 10000.0f;
|
hparams.rope_theta = 10000.0f;
|
||||||
|
get_u32(KEY_SPATIAL_MERGE_SIZE, hparams.spatial_merge_size, false);
|
||||||
} break;
|
} break;
|
||||||
case PROJECTOR_TYPE_QWEN25VL:
|
case PROJECTOR_TYPE_QWEN25VL:
|
||||||
{
|
{
|
||||||
|
@ -2007,11 +2049,14 @@ struct clip_model_loader {
|
||||||
case PROJECTOR_TYPE_PIXTRAL:
|
case PROJECTOR_TYPE_PIXTRAL:
|
||||||
{
|
{
|
||||||
vision_model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight"));
|
vision_model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight"));
|
||||||
vision_model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias"));
|
vision_model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias"), false);
|
||||||
vision_model.mm_2_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight"));
|
vision_model.mm_2_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight"));
|
||||||
vision_model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"));
|
vision_model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"), false);
|
||||||
// [IMG_BREAK] token embedding
|
// [IMG_BREAK] token embedding
|
||||||
vision_model.token_embd_img_break = get_tensor(TN_TOK_IMG_BREAK);
|
vision_model.token_embd_img_break = get_tensor(TN_TOK_IMG_BREAK);
|
||||||
|
// for mistral small 3.1
|
||||||
|
vision_model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM, false);
|
||||||
|
vision_model.mm_patch_merger_w = get_tensor(TN_MM_PATCH_MERGER, false);
|
||||||
} break;
|
} break;
|
||||||
default:
|
default:
|
||||||
GGML_ASSERT(false && "unknown projector type");
|
GGML_ASSERT(false && "unknown projector type");
|
||||||
|
@ -2653,7 +2698,7 @@ struct llava_uhd {
|
||||||
|
|
||||||
// no pinpoints, dynamically calculate the grid size (e.g. minicpmv)
|
// no pinpoints, dynamically calculate the grid size (e.g. minicpmv)
|
||||||
|
|
||||||
auto best_size = get_best_resize(original_size, slice_size, patch_size, has_slices);
|
auto best_size = get_best_resize(original_size, slice_size, patch_size, !has_slices);
|
||||||
res.overview_size = best_size;
|
res.overview_size = best_size;
|
||||||
|
|
||||||
if (!has_slices) {
|
if (!has_slices) {
|
||||||
|
@ -3067,8 +3112,9 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
|
||||||
} else if (ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
|
} else if (ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
|
||||||
n_patches /= ctx->vision_model.hparams.proj_scale_factor;
|
n_patches /= ctx->vision_model.hparams.proj_scale_factor;
|
||||||
} else if (ctx->proj_type == PROJECTOR_TYPE_PIXTRAL) {
|
} else if (ctx->proj_type == PROJECTOR_TYPE_PIXTRAL) {
|
||||||
int n_patches_x = img->nx / params.patch_size;
|
int n_merge = ctx->vision_model.hparams.spatial_merge_size;
|
||||||
int n_patches_y = img->ny / params.patch_size;
|
int n_patches_x = img->nx / params.patch_size / (n_merge > 0 ? n_merge : 1);
|
||||||
|
int n_patches_y = img->ny / params.patch_size / (n_merge > 0 ? n_merge : 1);
|
||||||
n_patches = n_patches_y*n_patches_x + n_patches_y - 1; // + one [IMG_BREAK] per row, except the last row
|
n_patches = n_patches_y*n_patches_x + n_patches_y - 1; // + one [IMG_BREAK] per row, except the last row
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3654,7 +3700,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
|
||||||
return ctx->vision_model.mm_model_peg_0_b->ne[0];
|
return ctx->vision_model.mm_model_peg_0_b->ne[0];
|
||||||
case PROJECTOR_TYPE_MLP:
|
case PROJECTOR_TYPE_MLP:
|
||||||
case PROJECTOR_TYPE_PIXTRAL:
|
case PROJECTOR_TYPE_PIXTRAL:
|
||||||
return ctx->vision_model.mm_2_b->ne[0];
|
return ctx->vision_model.mm_2_w->ne[1];
|
||||||
case PROJECTOR_TYPE_MLP_NORM:
|
case PROJECTOR_TYPE_MLP_NORM:
|
||||||
return ctx->vision_model.mm_3_b->ne[0];
|
return ctx->vision_model.mm_3_b->ne[0];
|
||||||
case PROJECTOR_TYPE_MINICPMV:
|
case PROJECTOR_TYPE_MINICPMV:
|
||||||
|
|
|
@ -72,6 +72,8 @@ struct mtmd_cli_context {
|
||||||
llama_batch batch;
|
llama_batch batch;
|
||||||
int n_batch;
|
int n_batch;
|
||||||
|
|
||||||
|
std::vector<mtmd_bitmap> bitmaps;
|
||||||
|
|
||||||
// note: we know that gemma3 template is "linear", meaning each turn is completely separated to another
|
// note: we know that gemma3 template is "linear", meaning each turn is completely separated to another
|
||||||
// so here we don't need to keep track of chat history
|
// so here we don't need to keep track of chat history
|
||||||
common_chat_templates_ptr tmpls;
|
common_chat_templates_ptr tmpls;
|
||||||
|
@ -94,6 +96,7 @@ struct mtmd_cli_context {
|
||||||
LOG_ERR("Model does not have chat template.\n");
|
LOG_ERR("Model does not have chat template.\n");
|
||||||
LOG_ERR(" For old llava models, you may need to use '--chat-template vicuna'\n");
|
LOG_ERR(" For old llava models, you may need to use '--chat-template vicuna'\n");
|
||||||
LOG_ERR(" For MobileVLM models, use '--chat-template deepseek'\n");
|
LOG_ERR(" For MobileVLM models, use '--chat-template deepseek'\n");
|
||||||
|
LOG_ERR(" For Mistral Small 3.1, use '--chat-template mistral-v7'\n");
|
||||||
exit(1);
|
exit(1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -134,13 +137,22 @@ struct mtmd_cli_context {
|
||||||
antiprompt_tokens.begin()
|
antiprompt_tokens.begin()
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool load_image(const std::string & fname) {
|
||||||
|
mtmd_bitmap bitmap;
|
||||||
|
if (mtmd_helper_bitmap_init_from_file(fname.c_str(), bitmap)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
bitmaps.push_back(std::move(bitmap));
|
||||||
|
return true;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
static int generate_response(mtmd_cli_context & ctx, common_sampler * smpl, int n_predict) {
|
static int generate_response(mtmd_cli_context & ctx, common_sampler * smpl, int n_predict) {
|
||||||
llama_tokens generated_tokens;
|
llama_tokens generated_tokens;
|
||||||
for (int i = 0; i < n_predict; i++) {
|
for (int i = 0; i < n_predict; i++) {
|
||||||
if (i > n_predict || !g_is_generating || g_is_interrupted) {
|
if (i > n_predict || !g_is_generating || g_is_interrupted) {
|
||||||
printf("\n");
|
LOG("\n");
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -149,15 +161,15 @@ static int generate_response(mtmd_cli_context & ctx, common_sampler * smpl, int
|
||||||
common_sampler_accept(smpl, token_id, true);
|
common_sampler_accept(smpl, token_id, true);
|
||||||
|
|
||||||
if (llama_vocab_is_eog(ctx.vocab, token_id) || ctx.check_antiprompt(generated_tokens)) {
|
if (llama_vocab_is_eog(ctx.vocab, token_id) || ctx.check_antiprompt(generated_tokens)) {
|
||||||
printf("\n");
|
LOG("\n");
|
||||||
break; // end of generation
|
break; // end of generation
|
||||||
}
|
}
|
||||||
|
|
||||||
printf("%s", common_token_to_piece(ctx.lctx, token_id).c_str());
|
LOG("%s", common_token_to_piece(ctx.lctx, token_id).c_str());
|
||||||
fflush(stdout);
|
fflush(stdout);
|
||||||
|
|
||||||
if (g_is_interrupted) {
|
if (g_is_interrupted) {
|
||||||
printf("\n");
|
LOG("\n");
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -172,9 +184,7 @@ static int generate_response(mtmd_cli_context & ctx, common_sampler * smpl, int
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg, std::vector<std::string> & images_fname, bool add_bos = false) {
|
static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg, bool add_bos = false) {
|
||||||
std::vector<mtmd_bitmap> bitmaps;
|
|
||||||
|
|
||||||
common_chat_templates_inputs tmpl_inputs;
|
common_chat_templates_inputs tmpl_inputs;
|
||||||
tmpl_inputs.messages = {msg};
|
tmpl_inputs.messages = {msg};
|
||||||
tmpl_inputs.add_generation_prompt = true;
|
tmpl_inputs.add_generation_prompt = true;
|
||||||
|
@ -182,15 +192,6 @@ static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg, std::vect
|
||||||
auto formatted_chat = common_chat_templates_apply(ctx.tmpls.get(), tmpl_inputs);
|
auto formatted_chat = common_chat_templates_apply(ctx.tmpls.get(), tmpl_inputs);
|
||||||
LOG_DBG("formatted_chat.prompt: %s\n", formatted_chat.prompt.c_str());
|
LOG_DBG("formatted_chat.prompt: %s\n", formatted_chat.prompt.c_str());
|
||||||
|
|
||||||
for (auto & fname : images_fname) {
|
|
||||||
mtmd_bitmap bitmap;
|
|
||||||
if (mtmd_helper_bitmap_init_from_file(fname.c_str(), bitmap)) {
|
|
||||||
LOG_ERR("Unable to load image %s\n", fname.c_str());
|
|
||||||
return 2; // image not found
|
|
||||||
}
|
|
||||||
bitmaps.push_back(std::move(bitmap));
|
|
||||||
}
|
|
||||||
|
|
||||||
mtmd_input_text text;
|
mtmd_input_text text;
|
||||||
text.text = formatted_chat.prompt;
|
text.text = formatted_chat.prompt;
|
||||||
text.add_special = add_bos;
|
text.add_special = add_bos;
|
||||||
|
@ -199,12 +200,14 @@ static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg, std::vect
|
||||||
|
|
||||||
if (g_is_interrupted) return 0;
|
if (g_is_interrupted) return 0;
|
||||||
|
|
||||||
int32_t res = mtmd_tokenize(ctx.ctx_vision.get(), chunks, text, bitmaps);
|
int32_t res = mtmd_tokenize(ctx.ctx_vision.get(), chunks, text, ctx.bitmaps);
|
||||||
if (res != 0) {
|
if (res != 0) {
|
||||||
LOG_ERR("Unable to tokenize prompt, res = %d\n", res);
|
LOG_ERR("Unable to tokenize prompt, res = %d\n", res);
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ctx.bitmaps.clear();
|
||||||
|
|
||||||
if (mtmd_helper_eval(ctx.ctx_vision.get(), ctx.lctx, chunks, ctx.n_past, 0, ctx.n_batch)) {
|
if (mtmd_helper_eval(ctx.ctx_vision.get(), ctx.lctx, chunks, ctx.n_past, 0, ctx.n_batch)) {
|
||||||
LOG_ERR("Unable to eval prompt\n");
|
LOG_ERR("Unable to eval prompt\n");
|
||||||
return 1;
|
return 1;
|
||||||
|
@ -212,6 +215,8 @@ static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg, std::vect
|
||||||
|
|
||||||
ctx.n_past += mtmd_helper_get_n_pos(chunks);
|
ctx.n_past += mtmd_helper_get_n_pos(chunks);
|
||||||
|
|
||||||
|
LOG("\n");
|
||||||
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -234,7 +239,7 @@ int main(int argc, char ** argv) {
|
||||||
}
|
}
|
||||||
|
|
||||||
mtmd_cli_context ctx(params);
|
mtmd_cli_context ctx(params);
|
||||||
printf("%s: %s\n", __func__, params.model.path.c_str());
|
LOG("%s: loading model: %s\n", __func__, params.model.path.c_str());
|
||||||
|
|
||||||
bool is_single_turn = !params.prompt.empty() && !params.image.empty();
|
bool is_single_turn = !params.prompt.empty() && !params.image.empty();
|
||||||
|
|
||||||
|
@ -267,7 +272,12 @@ int main(int argc, char ** argv) {
|
||||||
common_chat_msg msg;
|
common_chat_msg msg;
|
||||||
msg.role = "user";
|
msg.role = "user";
|
||||||
msg.content = params.prompt;
|
msg.content = params.prompt;
|
||||||
if (eval_message(ctx, msg, params.image, true)) {
|
for (const auto & image : params.image) {
|
||||||
|
if (!ctx.load_image(image)) {
|
||||||
|
return 1; // error is already printed by libmtmd
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (eval_message(ctx, msg, true)) {
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
if (!g_is_interrupted && generate_response(ctx, smpl, n_predict)) {
|
if (!g_is_interrupted && generate_response(ctx, smpl, n_predict)) {
|
||||||
|
@ -282,7 +292,6 @@ int main(int argc, char ** argv) {
|
||||||
LOG("\n");
|
LOG("\n");
|
||||||
|
|
||||||
bool is_first_msg = true;
|
bool is_first_msg = true;
|
||||||
std::vector<std::string> images_fname;
|
|
||||||
std::string content;
|
std::string content;
|
||||||
|
|
||||||
while (!g_is_interrupted) {
|
while (!g_is_interrupted) {
|
||||||
|
@ -307,10 +316,17 @@ int main(int argc, char ** argv) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
g_is_generating = true;
|
g_is_generating = true;
|
||||||
if (line.find("/image") == 0) {
|
if (line == "/image" || line.find("/image ") == 0) {
|
||||||
|
if (line.size() < 8) {
|
||||||
|
LOG_ERR("ERR: Missing image filename\n");
|
||||||
|
continue;
|
||||||
|
}
|
||||||
std::string image = line.substr(7);
|
std::string image = line.substr(7);
|
||||||
images_fname.push_back(string_strip(image));
|
if (ctx.load_image(image)) {
|
||||||
content += "<__image__>";
|
LOG("Image %s loaded\n", image.c_str());
|
||||||
|
content += "<__image__>";
|
||||||
|
}
|
||||||
|
// else, error is already printed by libmtmd
|
||||||
continue;
|
continue;
|
||||||
} else {
|
} else {
|
||||||
content += line;
|
content += line;
|
||||||
|
@ -318,21 +334,14 @@ int main(int argc, char ** argv) {
|
||||||
common_chat_msg msg;
|
common_chat_msg msg;
|
||||||
msg.role = "user";
|
msg.role = "user";
|
||||||
msg.content = content;
|
msg.content = content;
|
||||||
int ret = eval_message(ctx, msg, images_fname, is_first_msg);
|
int ret = eval_message(ctx, msg, is_first_msg);
|
||||||
if (g_is_interrupted) break;
|
|
||||||
if (ret == 2) {
|
|
||||||
// non-fatal error
|
|
||||||
images_fname.clear();
|
|
||||||
content.clear();
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (ret) {
|
if (ret) {
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
if (g_is_interrupted) break;
|
||||||
if (generate_response(ctx, smpl, n_predict)) {
|
if (generate_response(ctx, smpl, n_predict)) {
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
images_fname.clear();
|
|
||||||
content.clear();
|
content.clear();
|
||||||
is_first_msg = false;
|
is_first_msg = false;
|
||||||
}
|
}
|
||||||
|
|
|
@ -590,7 +590,7 @@ int32_t mtmd_helper_eval(mtmd_context * ctx,
|
||||||
}
|
}
|
||||||
|
|
||||||
} else if (chunk.type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
|
} else if (chunk.type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
|
||||||
GGML_ASSERT(!is_last && "logits for last image chunk is not yet support");
|
GGML_ASSERT(!is_last && "logits for last image chunk is not yet supported");
|
||||||
GGML_ASSERT(chunk.tokens_image != nullptr);
|
GGML_ASSERT(chunk.tokens_image != nullptr);
|
||||||
int64_t t0 = ggml_time_ms();
|
int64_t t0 = ggml_time_ms();
|
||||||
if (ctx->print_timings) {
|
if (ctx->print_timings) {
|
||||||
|
|
|
@ -59,6 +59,7 @@ add_test "llama-mtmd-cli" "ggml-org/Qwen2.5-VL-3B-Instruct-GGUF:Q4_K_M"
|
||||||
|
|
||||||
# to test the big models, run: ./tests.sh big
|
# to test the big models, run: ./tests.sh big
|
||||||
add_test_big "llama-mtmd-cli" "ggml-org/pixtral-12b-GGUF:Q4_K_M"
|
add_test_big "llama-mtmd-cli" "ggml-org/pixtral-12b-GGUF:Q4_K_M"
|
||||||
|
add_test_big "llama-mtmd-cli" "ggml-org/Mistral-Small-3.1-24B-Instruct-2503-GGUF" "mistral-v7"
|
||||||
|
|
||||||
# these models always give the wrong answer, not sure why
|
# these models always give the wrong answer, not sure why
|
||||||
# add_test "llama-mtmd-cli" "ggml-org/SmolVLM-Instruct-GGUF:Q4_K_M"
|
# add_test "llama-mtmd-cli" "ggml-org/SmolVLM-Instruct-GGUF:Q4_K_M"
|
||||||
|
|
|
@ -816,7 +816,10 @@ static void ggml_gallocr_init_tensor(ggml_gallocr_t galloc, struct ggml_tensor *
|
||||||
static bool ggml_gallocr_node_needs_realloc(ggml_gallocr_t galloc, struct ggml_tensor * node, struct tensor_alloc * talloc) {
|
static bool ggml_gallocr_node_needs_realloc(ggml_gallocr_t galloc, struct ggml_tensor * node, struct tensor_alloc * talloc) {
|
||||||
size_t node_size = 0;
|
size_t node_size = 0;
|
||||||
if (!node->data && !node->view_src) {
|
if (!node->data && !node->view_src) {
|
||||||
GGML_ASSERT(talloc->buffer_id >= 0); // prevent segfault when misusing the API
|
// If we previously had data but don't now then reallocate
|
||||||
|
if (talloc->buffer_id < 0) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
node_size = ggml_backend_buft_get_alloc_size(galloc->bufts[talloc->buffer_id], node);
|
node_size = ggml_backend_buft_get_alloc_size(galloc->bufts[talloc->buffer_id], node);
|
||||||
}
|
}
|
||||||
return talloc->size_max >= node_size;
|
return talloc->size_max >= node_size;
|
||||||
|
|
|
@ -67,6 +67,24 @@
|
||||||
#include "ggml-vulkan-shaders-noext.cpp"
|
#include "ggml-vulkan-shaders-noext.cpp"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
// remove this once it's more widely available in the SDK
|
||||||
|
#if !defined(VK_KHR_shader_bfloat16)
|
||||||
|
|
||||||
|
#define VK_KHR_shader_bfloat16 1
|
||||||
|
#define VK_KHR_SHADER_BFLOAT16_SPEC_VERSION 1
|
||||||
|
#define VK_KHR_SHADER_BFLOAT16_EXTENSION_NAME "VK_KHR_shader_bfloat16"
|
||||||
|
#define VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_BFLOAT16_FEATURES_KHR ((VkStructureType)1000141000)
|
||||||
|
#define VK_COMPONENT_TYPE_BFLOAT16_KHR ((VkComponentTypeKHR)1000141000)
|
||||||
|
|
||||||
|
typedef struct VkPhysicalDeviceShaderBfloat16FeaturesKHR {
|
||||||
|
VkStructureType sType;
|
||||||
|
void* pNext;
|
||||||
|
VkBool32 shaderBFloat16Type;
|
||||||
|
VkBool32 shaderBFloat16DotProduct;
|
||||||
|
VkBool32 shaderBFloat16CooperativeMatrix;
|
||||||
|
} VkPhysicalDeviceShaderBfloat16FeaturesKHR;
|
||||||
|
#endif
|
||||||
|
|
||||||
#define ROUNDUP_POW2(M, N) (((M) + (N) - 1) & ~((N) - 1))
|
#define ROUNDUP_POW2(M, N) (((M) + (N) - 1) & ~((N) - 1))
|
||||||
#define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
|
#define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
|
||||||
static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; }
|
static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; }
|
||||||
|
@ -282,8 +300,9 @@ struct vk_device_struct {
|
||||||
bool subgroup_require_full_support;
|
bool subgroup_require_full_support;
|
||||||
|
|
||||||
bool coopmat_support;
|
bool coopmat_support;
|
||||||
bool coopmat_acc_f32_support;
|
bool coopmat_acc_f32_support {};
|
||||||
bool coopmat_acc_f16_support;
|
bool coopmat_acc_f16_support {};
|
||||||
|
bool coopmat_bf16_support {};
|
||||||
uint32_t coopmat_m;
|
uint32_t coopmat_m;
|
||||||
uint32_t coopmat_n;
|
uint32_t coopmat_n;
|
||||||
uint32_t coopmat_k;
|
uint32_t coopmat_k;
|
||||||
|
@ -309,6 +328,7 @@ struct vk_device_struct {
|
||||||
|
|
||||||
vk_matmul_pipeline pipeline_matmul_f32 {};
|
vk_matmul_pipeline pipeline_matmul_f32 {};
|
||||||
vk_matmul_pipeline pipeline_matmul_f32_f16 {};
|
vk_matmul_pipeline pipeline_matmul_f32_f16 {};
|
||||||
|
vk_matmul_pipeline pipeline_matmul_bf16 {};
|
||||||
vk_matmul_pipeline2 pipeline_matmul_f16;
|
vk_matmul_pipeline2 pipeline_matmul_f16;
|
||||||
vk_matmul_pipeline2 pipeline_matmul_f16_f32;
|
vk_matmul_pipeline2 pipeline_matmul_f16_f32;
|
||||||
|
|
||||||
|
@ -317,6 +337,7 @@ struct vk_device_struct {
|
||||||
vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_COUNT];
|
vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_COUNT];
|
||||||
|
|
||||||
vk_matmul_pipeline pipeline_matmul_id_f32 {};
|
vk_matmul_pipeline pipeline_matmul_id_f32 {};
|
||||||
|
vk_matmul_pipeline pipeline_matmul_id_bf16 {};
|
||||||
vk_matmul_pipeline2 pipeline_matmul_id_f16;
|
vk_matmul_pipeline2 pipeline_matmul_id_f16;
|
||||||
vk_matmul_pipeline2 pipeline_matmul_id_f16_f32;
|
vk_matmul_pipeline2 pipeline_matmul_id_f16_f32;
|
||||||
|
|
||||||
|
@ -349,8 +370,8 @@ struct vk_device_struct {
|
||||||
vk_pipeline pipeline_clamp_f32;
|
vk_pipeline pipeline_clamp_f32;
|
||||||
vk_pipeline pipeline_pad_f32;
|
vk_pipeline pipeline_pad_f32;
|
||||||
vk_pipeline pipeline_repeat_f32, pipeline_repeat_back_f32;
|
vk_pipeline pipeline_repeat_f32, pipeline_repeat_back_f32;
|
||||||
vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16;
|
vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16, pipeline_cpy_f32_bf16;
|
||||||
vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16;
|
vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f32_bf16;
|
||||||
vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT];
|
vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT];
|
||||||
vk_pipeline pipeline_cpy_quant_f32[GGML_TYPE_COUNT];
|
vk_pipeline pipeline_cpy_quant_f32[GGML_TYPE_COUNT];
|
||||||
vk_pipeline pipeline_norm_f32;
|
vk_pipeline pipeline_norm_f32;
|
||||||
|
@ -1807,6 +1828,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||||
if (!device->pipeline_matmul_id_f32) {
|
if (!device->pipeline_matmul_id_f32) {
|
||||||
device->pipeline_matmul_id_f32 = std::make_shared<vk_matmul_pipeline_struct>();
|
device->pipeline_matmul_id_f32 = std::make_shared<vk_matmul_pipeline_struct>();
|
||||||
}
|
}
|
||||||
|
if (!device->pipeline_matmul_bf16) {
|
||||||
|
device->pipeline_matmul_bf16 = std::make_shared<vk_matmul_pipeline_struct>();
|
||||||
|
}
|
||||||
|
if (!device->pipeline_matmul_id_bf16) {
|
||||||
|
device->pipeline_matmul_id_bf16 = std::make_shared<vk_matmul_pipeline_struct>();
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<std::future<void>> compiles;
|
std::vector<std::future<void>> compiles;
|
||||||
auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const std::string &entrypoint,
|
auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const std::string &entrypoint,
|
||||||
|
@ -1916,6 +1943,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||||
CREATE_MM(PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
|
CREATE_MM(PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
|
||||||
|
|
||||||
CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3)
|
CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3)
|
||||||
|
#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
|
||||||
|
if (device->coopmat_bf16_support) {
|
||||||
|
CREATE_MM(pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3)
|
||||||
|
}
|
||||||
|
#endif
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
||||||
|
@ -1937,6 +1969,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
||||||
|
|
||||||
CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
|
CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
|
||||||
|
#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
|
||||||
|
if (device->coopmat_bf16_support) {
|
||||||
|
CREATE_MM(pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
|
||||||
|
}
|
||||||
|
#endif
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
||||||
|
@ -1990,6 +2027,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||||
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
||||||
|
#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
|
||||||
|
if (device->coopmat_bf16_support) {
|
||||||
|
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, )
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
if (device->coopmat_acc_f16_support) {
|
if (device->coopmat_acc_f16_support) {
|
||||||
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||||
|
@ -2038,6 +2080,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||||
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
||||||
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
||||||
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
||||||
|
#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
|
||||||
|
if (device->coopmat_bf16_support) {
|
||||||
|
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
if (device->coopmat_acc_f16_support) {
|
if (device->coopmat_acc_f16_support) {
|
||||||
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
|
@ -2120,6 +2167,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||||
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
||||||
|
|
||||||
|
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
||||||
|
|
||||||
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||||
|
@ -2155,6 +2204,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||||
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
||||||
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
||||||
|
|
||||||
|
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
|
|
||||||
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
|
@ -2207,6 +2258,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||||
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16_f32.f32acc, matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16_f32.f32acc, matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
||||||
|
|
||||||
|
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
||||||
|
|
||||||
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||||
|
@ -2242,6 +2295,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||||
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
||||||
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
||||||
|
|
||||||
|
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
|
|
||||||
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
|
@ -2262,8 +2317,26 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||||
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
#undef CREATE_MM
|
|
||||||
}
|
}
|
||||||
|
// reusing CREATE_MM from the fp32 path
|
||||||
|
if ((device->coopmat2 || device->coopmat_support)
|
||||||
|
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||||
|
&& !device->coopmat_bf16_support
|
||||||
|
#endif
|
||||||
|
) {
|
||||||
|
// use scalar tile sizes
|
||||||
|
l_warptile = { 128, 128, 128, 16, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 };
|
||||||
|
m_warptile = { 128, 64, 64, 16, subgroup_size_8, 32, 2, 4, 2, 1, subgroup_size_8 };
|
||||||
|
s_warptile = { subgroup_size_16, 32, 32, 16, 32, 32, 2, 2, 2, 1, subgroup_size_8 };
|
||||||
|
|
||||||
|
l_wg_denoms = {128, 128, 1 };
|
||||||
|
m_wg_denoms = { 64, 64, 1 };
|
||||||
|
s_wg_denoms = { 32, 32, 1 };
|
||||||
|
|
||||||
|
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
||||||
|
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
|
}
|
||||||
|
#undef CREATE_MM
|
||||||
|
|
||||||
// mul mat vec
|
// mul mat vec
|
||||||
|
|
||||||
|
@ -2282,6 +2355,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||||
for (uint32_t i = 0; i < mul_mat_vec_max_cols; ++i) {
|
for (uint32_t i = 0; i < mul_mat_vec_max_cols; ++i) {
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f32_f32_"+std::to_string(i+1), mul_mat_vec_f32_f32_f32_len, mul_mat_vec_f32_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f32_f32_"+std::to_string(i+1), mul_mat_vec_f32_f32_f32_len, mul_mat_vec_f32_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f32_f32_"+std::to_string(i+1), mul_mat_vec_f16_f32_f32_len, mul_mat_vec_f16_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f32_f32_"+std::to_string(i+1), mul_mat_vec_f16_f32_f32_len, mul_mat_vec_f16_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
|
||||||
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f32_f32_"+std::to_string(i+1), mul_mat_vec_bf16_f32_f32_len, mul_mat_vec_bf16_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_0_f32_f32_len, mul_mat_vec_q4_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_0_f32_f32_len, mul_mat_vec_q4_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_1_f32_f32_len, mul_mat_vec_q4_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_1_f32_f32_len, mul_mat_vec_q4_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f32_f32_"+std::to_string(i+1), mul_mat_vec_q5_0_f32_f32_len, mul_mat_vec_q5_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f32_f32_"+std::to_string(i+1), mul_mat_vec_q5_0_f32_f32_len, mul_mat_vec_q5_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
|
||||||
|
@ -2304,6 +2378,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||||
|
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32_"+std::to_string(i+1), mul_mat_vec_f32_f16_f32_len, mul_mat_vec_f32_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32_"+std::to_string(i+1), mul_mat_vec_f32_f16_f32_len, mul_mat_vec_f32_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32_"+std::to_string(i+1), mul_mat_vec_f16_f16_f32_len, mul_mat_vec_f16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32_"+std::to_string(i+1), mul_mat_vec_f16_f16_f32_len, mul_mat_vec_f16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
|
||||||
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f16_f32_"+std::to_string(i+1), mul_mat_vec_bf16_f16_f32_len, mul_mat_vec_bf16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_0_f16_f32_len, mul_mat_vec_q4_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_0_f16_f32_len, mul_mat_vec_q4_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_1_f16_f32_len, mul_mat_vec_q4_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_1_f16_f32_len, mul_mat_vec_q4_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f16_f32_"+std::to_string(i+1), mul_mat_vec_q5_0_f16_f32_len, mul_mat_vec_q5_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f16_f32_"+std::to_string(i+1), mul_mat_vec_q5_0_f16_f32_len, mul_mat_vec_q5_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
|
||||||
|
@ -2327,6 +2402,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||||
|
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", mul_mat_vec_id_f16_f32_len, mul_mat_vec_id_f16_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", mul_mat_vec_id_f16_f32_len, mul_mat_vec_id_f16_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
|
||||||
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_BF16], "mul_mat_vec_id_bf16_f32", mul_mat_vec_id_bf16_f32_len, mul_mat_vec_id_bf16_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", mul_mat_vec_id_q4_0_f32_len, mul_mat_vec_id_q4_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", mul_mat_vec_id_q4_0_f32_len, mul_mat_vec_id_q4_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_f32", mul_mat_vec_id_q4_1_f32_len, mul_mat_vec_id_q4_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_f32", mul_mat_vec_id_q4_1_f32_len, mul_mat_vec_id_q4_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", mul_mat_vec_id_q5_0_f32_len, mul_mat_vec_id_q5_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", mul_mat_vec_id_q5_0_f32_len, mul_mat_vec_id_q5_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
|
||||||
|
@ -2372,6 +2448,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||||
// get_rows
|
// get_rows
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F32 ], "get_rows_f32", get_rows_f32_len, get_rows_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F32 ], "get_rows_f32", get_rows_f32_len, get_rows_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F16 ], "get_rows_f16", get_rows_f16_len, get_rows_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F16 ], "get_rows_f16", get_rows_f16_len, get_rows_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
|
||||||
|
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_BF16], "get_rows_bf16", get_rows_bf16_len, get_rows_bf16_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_0], "get_rows_q4_0", get_rows_q4_0_len, get_rows_q4_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_0], "get_rows_q4_0", get_rows_q4_0_len, get_rows_q4_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_1], "get_rows_q4_1", get_rows_q4_1_len, get_rows_q4_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_1], "get_rows_q4_1", get_rows_q4_1_len, get_rows_q4_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_0], "get_rows_q5_0", get_rows_q5_0_len, get_rows_q5_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_0], "get_rows_q5_0", get_rows_q5_0_len, get_rows_q5_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||||
|
@ -2389,6 +2466,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||||
|
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f32_f32", get_rows_f32_f32_len, get_rows_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f32_f32", get_rows_f32_f32_len, get_rows_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F16 ], "get_rows_f16_f32", get_rows_f16_f32_len, get_rows_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F16 ], "get_rows_f16_f32", get_rows_f16_f32_len, get_rows_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
|
||||||
|
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_BF16], "get_rows_bf16_f32", get_rows_bf16_f32_len, get_rows_bf16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_0], "get_rows_q4_0_f32", get_rows_q4_0_f32_len, get_rows_q4_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_0], "get_rows_q4_0_f32", get_rows_q4_0_f32_len, get_rows_q4_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_1], "get_rows_q4_1_f32", get_rows_q4_1_f32_len, get_rows_q4_1_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_1], "get_rows_q4_1_f32", get_rows_q4_1_f32_len, get_rows_q4_1_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_0], "get_rows_q5_0_f32", get_rows_q5_0_f32_len, get_rows_q5_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_0], "get_rows_q5_0_f32", get_rows_q5_0_f32_len, get_rows_q5_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||||
|
@ -2415,7 +2493,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32[i], "mul_mat_vec_p021_f16_f32"+std::to_string(i+1), mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {device->subgroup_size, i + 1}, 1, true);
|
ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32[i], "mul_mat_vec_p021_f16_f32"+std::to_string(i+1), mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {device->subgroup_size, i + 1}, 1, true);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 7 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 9 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
|
||||||
|
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
|
||||||
|
@ -2426,10 +2504,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f16, "cpy_f16_f16", cpy_f16_f16_len, cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f16, "cpy_f16_f16", cpy_f16_f16_len, cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||||
|
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_bf16,"cpy_f32_bf16",cpy_f32_bf16_len,cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||||
|
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f32, "contig_cpy_f32_f32", contig_cpy_f32_f32_len, contig_cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f32, "contig_cpy_f32_f32", contig_cpy_f32_f32_len, contig_cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f16, "contig_cpy_f32_f16", contig_cpy_f32_f16_len, contig_cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f16, "contig_cpy_f32_f16", contig_cpy_f32_f16_len, contig_cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f16, "contig_cpy_f16_f16", contig_cpy_f16_f16_len, contig_cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f16, "contig_cpy_f16_f16", contig_cpy_f16_f16_len, contig_cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||||
|
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_bf16,"contig_cpy_f32_bf16",contig_cpy_f32_bf16_len,contig_cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||||
|
|
||||||
if (device->float_controls_rte_fp16) {
|
if (device->float_controls_rte_fp16) {
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_rte_len, cpy_f32_q4_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_rte_len, cpy_f32_q4_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1);
|
||||||
|
@ -2594,6 +2675,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
||||||
bool coopmat2_support = false;
|
bool coopmat2_support = false;
|
||||||
device->coopmat_support = false;
|
device->coopmat_support = false;
|
||||||
device->integer_dot_product = false;
|
device->integer_dot_product = false;
|
||||||
|
bool bfloat16_support = false;
|
||||||
|
|
||||||
for (const auto& properties : ext_props) {
|
for (const auto& properties : ext_props) {
|
||||||
if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
|
if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
|
||||||
|
@ -2624,6 +2706,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
||||||
!getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) {
|
!getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) {
|
||||||
device->integer_dot_product = true;
|
device->integer_dot_product = true;
|
||||||
#endif
|
#endif
|
||||||
|
} else if (strcmp("VK_KHR_shader_bfloat16", properties.extensionName) == 0 &&
|
||||||
|
!getenv("GGML_VK_DISABLE_BFLOAT16")) {
|
||||||
|
bfloat16_support = true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2818,6 +2903,17 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#if defined(VK_KHR_shader_bfloat16)
|
||||||
|
VkPhysicalDeviceShaderBfloat16FeaturesKHR bfloat16_features {};
|
||||||
|
bfloat16_features.pNext = nullptr;
|
||||||
|
bfloat16_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_BFLOAT16_FEATURES_KHR;
|
||||||
|
if (bfloat16_support) {
|
||||||
|
last_struct->pNext = (VkBaseOutStructure *)&bfloat16_features;
|
||||||
|
last_struct = (VkBaseOutStructure *)&bfloat16_features;
|
||||||
|
device_extensions.push_back("VK_KHR_shader_bfloat16");
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
VkPhysicalDeviceMaintenance4Features maint4_features {};
|
VkPhysicalDeviceMaintenance4Features maint4_features {};
|
||||||
maint4_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MAINTENANCE_4_FEATURES;
|
maint4_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MAINTENANCE_4_FEATURES;
|
||||||
if (maintenance4_support) {
|
if (maintenance4_support) {
|
||||||
|
@ -3015,6 +3111,25 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
||||||
device->coopmat_int_n = prop.NSize;
|
device->coopmat_int_n = prop.NSize;
|
||||||
device->coopmat_int_k = prop.KSize;
|
device->coopmat_int_k = prop.KSize;
|
||||||
}
|
}
|
||||||
|
#if defined(VK_KHR_shader_bfloat16) && defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
|
||||||
|
if (prop.AType == VK_COMPONENT_TYPE_BFLOAT16_KHR &&
|
||||||
|
prop.BType == VK_COMPONENT_TYPE_BFLOAT16_KHR &&
|
||||||
|
prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR &&
|
||||||
|
prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR &&
|
||||||
|
(vk::ScopeKHR)prop.scope == vk::ScopeKHR::eSubgroup
|
||||||
|
) {
|
||||||
|
// coopmat sizes not set yet
|
||||||
|
if (device->coopmat_m == 0) {
|
||||||
|
device->coopmat_bf16_support = true;
|
||||||
|
device->coopmat_m = prop.MSize;
|
||||||
|
device->coopmat_n = prop.NSize;
|
||||||
|
device->coopmat_k = prop.KSize;
|
||||||
|
} else if (device->coopmat_m == prop.MSize && device->coopmat_n == prop.NSize && device->coopmat_k == prop.KSize) {
|
||||||
|
// Only enable if shape is identical
|
||||||
|
device->coopmat_bf16_support = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
if (device->coopmat_m == 0 || !device->coopmat_acc_f32_support) {
|
if (device->coopmat_m == 0 || !device->coopmat_acc_f32_support) {
|
||||||
|
@ -3022,11 +3137,19 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
||||||
GGML_LOG_DEBUG("ggml_vulkan: WARNING: No suitable matrix core mode found. Disabling matrix cores.\n");
|
GGML_LOG_DEBUG("ggml_vulkan: WARNING: No suitable matrix core mode found. Disabling matrix cores.\n");
|
||||||
device->coopmat_support = false;
|
device->coopmat_support = false;
|
||||||
}
|
}
|
||||||
|
if (getenv("GGML_VK_DISABLE_BFLOAT16")) {
|
||||||
|
device->coopmat_bf16_support = false;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (device->coopmat_support) {
|
if (device->coopmat_support) {
|
||||||
device_extensions.push_back("VK_KHR_cooperative_matrix");
|
device_extensions.push_back("VK_KHR_cooperative_matrix");
|
||||||
}
|
}
|
||||||
|
#if defined(VK_KHR_shader_bfloat16)
|
||||||
|
if (device->coopmat_bf16_support) {
|
||||||
|
device_extensions.push_back("VK_KHR_shader_bfloat16");
|
||||||
|
}
|
||||||
|
#endif
|
||||||
#endif
|
#endif
|
||||||
device->name = GGML_VK_NAME + std::to_string(idx);
|
device->name = GGML_VK_NAME + std::to_string(idx);
|
||||||
|
|
||||||
|
@ -3483,6 +3606,9 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
|
||||||
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
|
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
|
||||||
return ctx->device->pipeline_matmul_f32_f16;
|
return ctx->device->pipeline_matmul_f32_f16;
|
||||||
}
|
}
|
||||||
|
if (src0_type == GGML_TYPE_BF16 && src1_type == GGML_TYPE_BF16) {
|
||||||
|
return ctx->device->pipeline_matmul_bf16;
|
||||||
|
}
|
||||||
if (prec == GGML_PREC_DEFAULT && ctx->device->fp16 && !(ctx->device->coopmat_support && !ctx->device->coopmat_acc_f16_support)) {
|
if (prec == GGML_PREC_DEFAULT && ctx->device->fp16 && !(ctx->device->coopmat_support && !ctx->device->coopmat_acc_f16_support)) {
|
||||||
if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
|
if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
|
||||||
return ctx->device->pipeline_matmul_f16_f32.f16acc;
|
return ctx->device->pipeline_matmul_f16_f32.f16acc;
|
||||||
|
@ -3554,6 +3680,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context *
|
||||||
switch (a_type) {
|
switch (a_type) {
|
||||||
case GGML_TYPE_F32:
|
case GGML_TYPE_F32:
|
||||||
case GGML_TYPE_F16:
|
case GGML_TYPE_F16:
|
||||||
|
case GGML_TYPE_BF16:
|
||||||
case GGML_TYPE_Q4_0:
|
case GGML_TYPE_Q4_0:
|
||||||
case GGML_TYPE_Q4_1:
|
case GGML_TYPE_Q4_1:
|
||||||
case GGML_TYPE_Q5_0:
|
case GGML_TYPE_Q5_0:
|
||||||
|
@ -3586,6 +3713,9 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
|
||||||
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
|
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
|
||||||
return ctx->device->pipeline_matmul_id_f32;
|
return ctx->device->pipeline_matmul_id_f32;
|
||||||
}
|
}
|
||||||
|
if (src0_type == GGML_TYPE_BF16 && src1_type == GGML_TYPE_BF16) {
|
||||||
|
return ctx->device->pipeline_matmul_id_bf16;
|
||||||
|
}
|
||||||
if (prec == GGML_PREC_DEFAULT && ctx->device->fp16 && !(ctx->device->coopmat_support && !ctx->device->coopmat_acc_f16_support)) {
|
if (prec == GGML_PREC_DEFAULT && ctx->device->fp16 && !(ctx->device->coopmat_support && !ctx->device->coopmat_acc_f16_support)) {
|
||||||
if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
|
if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
|
||||||
return ctx->device->pipeline_matmul_id_f16_f32.f16acc;
|
return ctx->device->pipeline_matmul_id_f16_f32.f16acc;
|
||||||
|
@ -3639,6 +3769,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context
|
||||||
switch (a_type) {
|
switch (a_type) {
|
||||||
case GGML_TYPE_F32:
|
case GGML_TYPE_F32:
|
||||||
case GGML_TYPE_F16:
|
case GGML_TYPE_F16:
|
||||||
|
case GGML_TYPE_BF16:
|
||||||
case GGML_TYPE_Q4_0:
|
case GGML_TYPE_Q4_0:
|
||||||
case GGML_TYPE_Q4_1:
|
case GGML_TYPE_Q4_1:
|
||||||
case GGML_TYPE_Q5_0:
|
case GGML_TYPE_Q5_0:
|
||||||
|
@ -4374,6 +4505,13 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const
|
||||||
return ctx->device->pipeline_cpy_f16_f16;
|
return ctx->device->pipeline_cpy_f16_f16;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_BF16) {
|
||||||
|
if (contig) {
|
||||||
|
return ctx->device->pipeline_contig_cpy_f32_bf16;
|
||||||
|
} else {
|
||||||
|
return ctx->device->pipeline_cpy_f32_bf16;
|
||||||
|
}
|
||||||
|
}
|
||||||
if (src->type == GGML_TYPE_F32) {
|
if (src->type == GGML_TYPE_F32) {
|
||||||
switch (to) {
|
switch (to) {
|
||||||
case GGML_TYPE_Q4_0:
|
case GGML_TYPE_Q4_0:
|
||||||
|
@ -4501,8 +4639,12 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
|
||||||
const bool x_non_contig = (ctx->device->coopmat2 && src0->type == GGML_TYPE_F32) ||
|
const bool x_non_contig = (ctx->device->coopmat2 && src0->type == GGML_TYPE_F32) ||
|
||||||
!ggml_vk_dim01_contiguous(src0);
|
!ggml_vk_dim01_contiguous(src0);
|
||||||
const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) ||
|
const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) ||
|
||||||
|
(src0->type == GGML_TYPE_BF16 && src1->type != GGML_TYPE_BF16) ||
|
||||||
!ggml_vk_dim01_contiguous(src1);
|
!ggml_vk_dim01_contiguous(src1);
|
||||||
|
|
||||||
|
// If src0 is BF16, try to use a BF16 x BF16 multiply
|
||||||
|
ggml_type f16_type = src0->type == GGML_TYPE_BF16 ? GGML_TYPE_BF16 : GGML_TYPE_F16;
|
||||||
|
|
||||||
const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
|
const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
|
||||||
|
|
||||||
bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && (ne11 * ne10) % 4 == 0;
|
bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && (ne11 * ne10) % 4 == 0;
|
||||||
|
@ -4512,25 +4654,25 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
|
||||||
|
|
||||||
if (mmp == nullptr) {
|
if (mmp == nullptr) {
|
||||||
// Fall back to f16 dequant mul mat
|
// Fall back to f16 dequant mul mat
|
||||||
mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type, (ggml_prec)dst->op_params[0]);
|
mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, y_non_contig ? f16_type : src1->type, (ggml_prec)dst->op_params[0]);
|
||||||
quantize_y = false;
|
quantize_y = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
|
const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
|
||||||
const bool qy_needs_dequant = !quantize_y && ((src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig);
|
const bool qy_needs_dequant = !quantize_y && ((src1->type != f16_type && !y_f32_kernel) || y_non_contig);
|
||||||
|
|
||||||
if (qx_needs_dequant) {
|
if (qx_needs_dequant) {
|
||||||
// Fall back to dequant + f16 mulmat
|
// Fall back to dequant + f16 mulmat
|
||||||
mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, GGML_TYPE_F16, y_f32_kernel ? GGML_TYPE_F32 : GGML_TYPE_F16, (ggml_prec)dst->op_params[0]);
|
mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, f16_type, y_f32_kernel ? GGML_TYPE_F32 : f16_type, (ggml_prec)dst->op_params[0]);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Not implemented
|
// Not implemented
|
||||||
GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
|
GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
|
||||||
|
|
||||||
const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? GGML_TYPE_F16 : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type)));
|
const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? f16_type : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type)));
|
||||||
const bool aligned = !quantize_y && ne10 == kpad && ne01 > 8 && ne11 > 8;
|
const bool aligned = !quantize_y && ne10 == kpad && ne01 > 8 && ne11 > 8;
|
||||||
|
|
||||||
vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? GGML_TYPE_F16 : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type));
|
vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? f16_type : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type));
|
||||||
|
|
||||||
// Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking
|
// Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking
|
||||||
uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) : ne11;
|
uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) : ne11;
|
||||||
|
@ -4551,12 +4693,12 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
|
||||||
vk_pipeline to_q8_1 = nullptr;
|
vk_pipeline to_q8_1 = nullptr;
|
||||||
|
|
||||||
if (x_non_contig) {
|
if (x_non_contig) {
|
||||||
to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, GGML_TYPE_F16);
|
to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, f16_type);
|
||||||
} else {
|
} else {
|
||||||
to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type);
|
to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type);
|
||||||
}
|
}
|
||||||
if (y_non_contig) {
|
if (y_non_contig) {
|
||||||
to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, GGML_TYPE_F16);
|
to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, f16_type);
|
||||||
} else {
|
} else {
|
||||||
to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
|
to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
|
||||||
}
|
}
|
||||||
|
@ -4973,6 +5115,8 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
|
||||||
const uint64_t nb01 = src0->nb[1];
|
const uint64_t nb01 = src0->nb[1];
|
||||||
const uint64_t nb02 = src0->nb[2];
|
const uint64_t nb02 = src0->nb[2];
|
||||||
|
|
||||||
|
const uint64_t nb12 = src1->nb[2];
|
||||||
|
|
||||||
// const uint64_t ne10 = src1->ne[0];
|
// const uint64_t ne10 = src1->ne[0];
|
||||||
const uint64_t ne11 = src1->ne[1];
|
const uint64_t ne11 = src1->ne[1];
|
||||||
const uint64_t ne12 = src1->ne[2];
|
const uint64_t ne12 = src1->ne[2];
|
||||||
|
@ -4998,6 +5142,7 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
|
||||||
|
|
||||||
const uint32_t row_stride_x = nb01 / sizeof(ggml_fp16_t);
|
const uint32_t row_stride_x = nb01 / sizeof(ggml_fp16_t);
|
||||||
const uint32_t channel_stride_x = nb02 / sizeof(ggml_fp16_t);
|
const uint32_t channel_stride_x = nb02 / sizeof(ggml_fp16_t);
|
||||||
|
const uint32_t channel_stride_y = nb12 / sizeof(float);
|
||||||
|
|
||||||
const uint64_t qx_sz = ggml_nbytes(src0);
|
const uint64_t qx_sz = ggml_nbytes(src0);
|
||||||
const uint64_t qy_sz = ggml_nbytes(src1);
|
const uint64_t qy_sz = ggml_nbytes(src1);
|
||||||
|
@ -5028,7 +5173,7 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
|
||||||
const uint64_t d_shader_offset = d_buf_offset - d_buffer_offset;
|
const uint64_t d_shader_offset = d_buf_offset - d_buffer_offset;
|
||||||
|
|
||||||
// compute
|
// compute
|
||||||
const std::array<uint32_t, 7> pc = { (uint32_t)ne00, (uint32_t)ne01, row_stride_x, channel_stride_x, (uint32_t)(ne12 / ne02), (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)) };
|
const std::array<uint32_t, 9> pc = { (uint32_t)ne00, (uint32_t)ne01, row_stride_x, channel_stride_x, channel_stride_y, (uint32_t)(ne12 / ne02), (uint32_t)ne12, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)) };
|
||||||
ggml_vk_sync_buffers(subctx);
|
ggml_vk_sync_buffers(subctx);
|
||||||
ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_nc_f16_f32,
|
ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_nc_f16_f32,
|
||||||
{ vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 7 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, (uint32_t)ne12 });
|
{ vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 7 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, (uint32_t)ne12 });
|
||||||
|
@ -5053,7 +5198,7 @@ static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, c
|
||||||
// mul_mat_vec supports batching ne12*ne13 when ne11==1, or treating ne11 as the batch size (up to four)
|
// mul_mat_vec supports batching ne12*ne13 when ne11==1, or treating ne11 as the batch size (up to four)
|
||||||
// when ne12 and ne13 are one.
|
// when ne12 and ne13 are one.
|
||||||
} else if ((dst->ne[1] == 1 || (dst->ne[1] <= mul_mat_vec_max_cols && src1->ne[2] * src1->ne[3] == 1)) &&
|
} else if ((dst->ne[1] == 1 || (dst->ne[1] <= mul_mat_vec_max_cols && src1->ne[2] * src1->ne[3] == 1)) &&
|
||||||
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) {
|
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || ggml_is_quantized(src0->type))) {
|
||||||
ggml_vk_mul_mat_vec_q_f16(ctx, subctx, src0, src1, dst, dryrun);
|
ggml_vk_mul_mat_vec_q_f16(ctx, subctx, src0, src1, dst, dryrun);
|
||||||
} else {
|
} else {
|
||||||
ggml_vk_mul_mat_q_f16(ctx, subctx, src0, src1, dst, dryrun);
|
ggml_vk_mul_mat_q_f16(ctx, subctx, src0, src1, dst, dryrun);
|
||||||
|
@ -5121,27 +5266,31 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
|
||||||
const bool x_non_contig = (ctx->device->coopmat2 && src0->type == GGML_TYPE_F32) ||
|
const bool x_non_contig = (ctx->device->coopmat2 && src0->type == GGML_TYPE_F32) ||
|
||||||
!ggml_vk_dim01_contiguous(src0);
|
!ggml_vk_dim01_contiguous(src0);
|
||||||
const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) ||
|
const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) ||
|
||||||
|
(src0->type == GGML_TYPE_BF16 && src1->type != GGML_TYPE_BF16) ||
|
||||||
!ggml_vk_dim01_contiguous(src1);
|
!ggml_vk_dim01_contiguous(src1);
|
||||||
|
|
||||||
|
// If src0 is BF16, try to use a BF16 x BF16 multiply
|
||||||
|
ggml_type f16_type = src0->type == GGML_TYPE_BF16 ? GGML_TYPE_BF16 : GGML_TYPE_F16;
|
||||||
|
|
||||||
const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
|
const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
|
||||||
|
|
||||||
vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type, (ggml_prec)dst->op_params[0]);
|
vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, y_non_contig ? f16_type : src1->type, (ggml_prec)dst->op_params[0]);
|
||||||
|
|
||||||
const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
|
const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
|
||||||
const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig;
|
const bool qy_needs_dequant = (src1->type != f16_type && !y_f32_kernel) || y_non_contig;
|
||||||
|
|
||||||
if (qx_needs_dequant) {
|
if (qx_needs_dequant) {
|
||||||
// Fall back to dequant + f16 mulmat
|
// Fall back to dequant + f16 mulmat
|
||||||
mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, GGML_TYPE_F16, y_f32_kernel ? GGML_TYPE_F32 : GGML_TYPE_F16, (ggml_prec)dst->op_params[0]);
|
mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, f16_type, y_f32_kernel ? GGML_TYPE_F32 : f16_type, (ggml_prec)dst->op_params[0]);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Not implemented
|
// Not implemented
|
||||||
GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
|
GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
|
||||||
|
|
||||||
const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? GGML_TYPE_F16 : src0->type));
|
const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? f16_type : src0->type));
|
||||||
const bool aligned = ne10 == kpad && ne01 > 8 && nei1 > 8;
|
const bool aligned = ne10 == kpad && ne01 > 8 && nei1 > 8;
|
||||||
|
|
||||||
vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? GGML_TYPE_F16 : src0->type);
|
vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? f16_type : src0->type);
|
||||||
|
|
||||||
// Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking
|
// Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking
|
||||||
uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) :ne11;
|
uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) :ne11;
|
||||||
|
@ -5160,12 +5309,12 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
|
||||||
vk_pipeline to_fp16_vk_1 = nullptr;
|
vk_pipeline to_fp16_vk_1 = nullptr;
|
||||||
|
|
||||||
if (x_non_contig) {
|
if (x_non_contig) {
|
||||||
to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, GGML_TYPE_F16);
|
to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, f16_type);
|
||||||
} else {
|
} else {
|
||||||
to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type);
|
to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type);
|
||||||
}
|
}
|
||||||
if (y_non_contig) {
|
if (y_non_contig) {
|
||||||
to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, GGML_TYPE_F16);
|
to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, f16_type);
|
||||||
} else {
|
} else {
|
||||||
to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
|
to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
|
||||||
}
|
}
|
||||||
|
@ -9251,6 +9400,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||||
switch (src0_type) {
|
switch (src0_type) {
|
||||||
case GGML_TYPE_F32:
|
case GGML_TYPE_F32:
|
||||||
case GGML_TYPE_F16:
|
case GGML_TYPE_F16:
|
||||||
|
case GGML_TYPE_BF16:
|
||||||
case GGML_TYPE_Q4_0:
|
case GGML_TYPE_Q4_0:
|
||||||
case GGML_TYPE_Q4_1:
|
case GGML_TYPE_Q4_1:
|
||||||
case GGML_TYPE_Q5_0:
|
case GGML_TYPE_Q5_0:
|
||||||
|
@ -9286,10 +9436,15 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||||
if (a->ne[3] != b->ne[3]) {
|
if (a->ne[3] != b->ne[3]) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (!(ggml_vk_dim01_contiguous(op->src[0]) || op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) ||
|
if (!(ggml_vk_dim01_contiguous(op->src[0]) || op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_BF16) ||
|
||||||
!(ggml_vk_dim01_contiguous(op->src[1]) || op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16)) {
|
!(ggml_vk_dim01_contiguous(op->src[1]) || op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
if (op->src[0]->type == GGML_TYPE_BF16 && op->src[1]->type == GGML_TYPE_F16) {
|
||||||
|
// We currently don't have a bf16 x f16 shader, or an fp16->bf16 copy shader.
|
||||||
|
// So don't support this combination for now.
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
} break;
|
} break;
|
||||||
|
@ -9362,6 +9517,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||||
switch (op->src[0]->type) {
|
switch (op->src[0]->type) {
|
||||||
case GGML_TYPE_F32:
|
case GGML_TYPE_F32:
|
||||||
case GGML_TYPE_F16:
|
case GGML_TYPE_F16:
|
||||||
|
case GGML_TYPE_BF16:
|
||||||
case GGML_TYPE_Q4_0:
|
case GGML_TYPE_Q4_0:
|
||||||
case GGML_TYPE_Q4_1:
|
case GGML_TYPE_Q4_1:
|
||||||
case GGML_TYPE_Q5_0:
|
case GGML_TYPE_Q5_0:
|
||||||
|
@ -9392,6 +9548,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||||
switch (src1_type) {
|
switch (src1_type) {
|
||||||
case GGML_TYPE_F32:
|
case GGML_TYPE_F32:
|
||||||
case GGML_TYPE_F16:
|
case GGML_TYPE_F16:
|
||||||
|
case GGML_TYPE_BF16:
|
||||||
case GGML_TYPE_Q4_0:
|
case GGML_TYPE_Q4_0:
|
||||||
case GGML_TYPE_Q4_1:
|
case GGML_TYPE_Q4_1:
|
||||||
case GGML_TYPE_Q5_0:
|
case GGML_TYPE_Q5_0:
|
||||||
|
|
|
@ -18,7 +18,11 @@ void main() {
|
||||||
// fast path for when all four iterations are in-bounds
|
// fast path for when all four iterations are in-bounds
|
||||||
if (idx + (num_iter-1)*num_threads < p.ne) {
|
if (idx + (num_iter-1)*num_threads < p.ne) {
|
||||||
[[unroll]] for (uint i = 0; i < num_iter; ++i) {
|
[[unroll]] for (uint i = 0; i < num_iter; ++i) {
|
||||||
#ifndef OPTIMIZATION_ERROR_WORKAROUND
|
|
||||||
|
#if defined(DATA_D_BF16)
|
||||||
|
float f = float(data_a[get_aoffset() + idx]);
|
||||||
|
data_d[get_doffset() + idx] = D_TYPE(fp32_to_bf16(f));
|
||||||
|
#elif !defined(OPTIMIZATION_ERROR_WORKAROUND)
|
||||||
data_d[get_doffset() + idx] = D_TYPE(data_a[get_aoffset() + idx]);
|
data_d[get_doffset() + idx] = D_TYPE(data_a[get_aoffset() + idx]);
|
||||||
#else
|
#else
|
||||||
data_d[get_doffset() + idx] = data_a[get_aoffset() + idx];
|
data_d[get_doffset() + idx] = data_a[get_aoffset() + idx];
|
||||||
|
@ -31,7 +35,10 @@ void main() {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifndef OPTIMIZATION_ERROR_WORKAROUND
|
#if defined(DATA_D_BF16)
|
||||||
|
float f = float(data_a[get_aoffset() + idx]);
|
||||||
|
data_d[get_doffset() + idx] = D_TYPE(fp32_to_bf16(f));
|
||||||
|
#elif !defined(OPTIMIZATION_ERROR_WORKAROUND)
|
||||||
data_d[get_doffset() + idx] = D_TYPE(data_a[get_aoffset() + idx]);
|
data_d[get_doffset() + idx] = D_TYPE(data_a[get_aoffset() + idx]);
|
||||||
#else
|
#else
|
||||||
data_d[get_doffset() + idx] = data_a[get_aoffset() + idx];
|
data_d[get_doffset() + idx] = data_a[get_aoffset() + idx];
|
||||||
|
|
|
@ -12,7 +12,10 @@ void main() {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifndef OPTIMIZATION_ERROR_WORKAROUND
|
#if defined(DATA_D_BF16)
|
||||||
|
float f = float(data_a[get_aoffset() + src0_idx(idx)]);
|
||||||
|
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(fp32_to_bf16(f));
|
||||||
|
#elif !defined(OPTIMIZATION_ERROR_WORKAROUND)
|
||||||
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
|
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
|
||||||
#else
|
#else
|
||||||
data_d[get_doffset() + dst_idx(idx)] = data_a[get_aoffset() + src0_idx(idx)];
|
data_d[get_doffset() + dst_idx(idx)] = data_a[get_aoffset() + src0_idx(idx)];
|
||||||
|
|
|
@ -23,6 +23,12 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#if defined(DATA_A_BF16)
|
||||||
|
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
|
||||||
|
return vec2(bf16_to_fp32(data_a[a_offset + ib]), bf16_to_fp32(data_a[a_offset + ib + 1]));
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
#if defined(DATA_A_Q4_0)
|
#if defined(DATA_A_Q4_0)
|
||||||
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
|
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
|
||||||
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
|
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
|
||||||
|
@ -428,7 +434,7 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if defined(DATA_A_F32) || defined(DATA_A_F16)
|
#if defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16)
|
||||||
vec2 get_dm(uint ib, uint a_offset) {
|
vec2 get_dm(uint ib, uint a_offset) {
|
||||||
return vec2(0, 0);
|
return vec2(0, 0);
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,9 +20,14 @@ void main() {
|
||||||
const uint a_offset = get_aoffset() + i01*p.nb01 + i11*p.nb02 + i12*p.nb03;
|
const uint a_offset = get_aoffset() + i01*p.nb01 + i11*p.nb02 + i12*p.nb03;
|
||||||
const uint d_offset = get_doffset() + i10*p.nb21 + i11*p.nb22 + i12*p.nb23;
|
const uint d_offset = get_doffset() + i10*p.nb21 + i11*p.nb22 + i12*p.nb23;
|
||||||
|
|
||||||
#ifndef OPTIMIZATION_ERROR_WORKAROUND
|
#if defined(DATA_A_BF16)
|
||||||
data_d[d_offset + i00] = D_TYPE(data_a[a_offset + i00]);
|
FLOAT_TYPE v = FLOAT_TYPE(bf16_to_fp32(data_a[a_offset + i00]));
|
||||||
#else
|
#else
|
||||||
data_d[d_offset + i00] = data_a[a_offset + i00];
|
FLOAT_TYPE v = FLOAT_TYPE(data_a[a_offset + i00]);
|
||||||
|
#endif
|
||||||
|
#ifndef OPTIMIZATION_ERROR_WORKAROUND
|
||||||
|
data_d[d_offset + i00] = D_TYPE(v);
|
||||||
|
#else
|
||||||
|
data_d[d_offset + i00] = D_TYPE(v);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,7 +6,7 @@
|
||||||
|
|
||||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||||
|
|
||||||
#if !defined(DATA_A_F32) && !defined(DATA_A_F16)
|
#if !defined(DATA_A_F32) && !defined(DATA_A_F16) && !defined(DATA_A_BF16)
|
||||||
#define K_PER_ITER 8
|
#define K_PER_ITER 8
|
||||||
#else
|
#else
|
||||||
#define K_PER_ITER 2
|
#define K_PER_ITER 2
|
||||||
|
|
|
@ -21,7 +21,9 @@ layout (push_constant) uniform parameter
|
||||||
uint nrows_x;
|
uint nrows_x;
|
||||||
uint row_stride_x;
|
uint row_stride_x;
|
||||||
uint channel_stride_x;
|
uint channel_stride_x;
|
||||||
|
uint channel_stride_y;
|
||||||
uint channel_x_divisor;
|
uint channel_x_divisor;
|
||||||
|
uint ne12;
|
||||||
uint b_offset;
|
uint b_offset;
|
||||||
uint d_offset;
|
uint d_offset;
|
||||||
} p;
|
} p;
|
||||||
|
@ -33,6 +35,7 @@ void main() {
|
||||||
const uint row_x = gl_GlobalInvocationID.y;
|
const uint row_x = gl_GlobalInvocationID.y;
|
||||||
const uint channel = gl_GlobalInvocationID.z;
|
const uint channel = gl_GlobalInvocationID.z;
|
||||||
const uint channel_x = channel / p.channel_x_divisor;
|
const uint channel_x = channel / p.channel_x_divisor;
|
||||||
|
const uint channel_y = channel % p.ne12;
|
||||||
|
|
||||||
const uint nrows_y = p.ncols_x;
|
const uint nrows_y = p.ncols_x;
|
||||||
const uint nrows_dst = p.nrows_x;
|
const uint nrows_dst = p.nrows_x;
|
||||||
|
@ -56,7 +59,7 @@ void main() {
|
||||||
const uint row_y = col_x;
|
const uint row_y = col_x;
|
||||||
|
|
||||||
const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
|
const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
|
||||||
const uint iy = channel*nrows_y + row_y;
|
const uint iy = channel_y*p.channel_stride_y + row_y;
|
||||||
|
|
||||||
const vec4 av4 = vec4(data_a_v4[ix / 4]);
|
const vec4 av4 = vec4(data_a_v4[ix / 4]);
|
||||||
const vec4 bv4 = vec4(data_b_v4[iy / 4]);
|
const vec4 bv4 = vec4(data_b_v4[iy / 4]);
|
||||||
|
@ -72,7 +75,7 @@ void main() {
|
||||||
const uint row_y = col_x;
|
const uint row_y = col_x;
|
||||||
|
|
||||||
const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
|
const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
|
||||||
const uint iy = channel*nrows_y + row_y;
|
const uint iy = channel_y*p.channel_stride_y + row_y;
|
||||||
|
|
||||||
const vec4 av4 = vec4(data_a_v4[ix / 4]);
|
const vec4 av4 = vec4(data_a_v4[ix / 4]);
|
||||||
const vec4 bv4 = vec4(data_b_v4[iy / 4]);
|
const vec4 bv4 = vec4(data_b_v4[iy / 4]);
|
||||||
|
@ -89,7 +92,7 @@ void main() {
|
||||||
const uint row_y = col_x;
|
const uint row_y = col_x;
|
||||||
|
|
||||||
const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
|
const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
|
||||||
const uint iy = channel*nrows_y + row_y;
|
const uint iy = channel_y*p.channel_stride_y + row_y;
|
||||||
|
|
||||||
const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]);
|
const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]);
|
||||||
|
|
||||||
|
|
|
@ -10,6 +10,10 @@
|
||||||
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
|
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#if defined(DATA_A_BF16) && defined(COOPMAT)
|
||||||
|
#extension GL_EXT_bfloat16 : enable
|
||||||
|
#endif
|
||||||
|
|
||||||
#ifdef COOPMAT
|
#ifdef COOPMAT
|
||||||
#extension GL_KHR_cooperative_matrix : enable
|
#extension GL_KHR_cooperative_matrix : enable
|
||||||
#extension GL_KHR_memory_scope_semantics : enable
|
#extension GL_KHR_memory_scope_semantics : enable
|
||||||
|
@ -29,6 +33,10 @@
|
||||||
#define LOAD_VEC_B 1
|
#define LOAD_VEC_B 1
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#if !defined(TO_FLOAT_TYPE)
|
||||||
|
#define TO_FLOAT_TYPE FLOAT_TYPE
|
||||||
|
#endif
|
||||||
|
|
||||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||||
|
|
||||||
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
|
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
|
||||||
|
@ -202,8 +210,8 @@ void main() {
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#ifdef COOPMAT
|
#ifdef COOPMAT
|
||||||
coopmat<float16_t, gl_ScopeSubgroup, TM, TK, gl_MatrixUseA> cache_a;
|
coopmat<FLOAT_TYPE, gl_ScopeSubgroup, TM, TK, gl_MatrixUseA> cache_a;
|
||||||
coopmat<float16_t, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> cache_b;
|
coopmat<FLOAT_TYPE, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> cache_b;
|
||||||
coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> sums[cms_per_row * cms_per_col];
|
coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> sums[cms_per_row * cms_per_col];
|
||||||
|
|
||||||
[[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) {
|
[[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) {
|
||||||
|
@ -248,6 +256,21 @@ void main() {
|
||||||
buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = FLOAT_TYPE(0.0f);
|
buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = FLOAT_TYPE(0.0f);
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
#elif defined(DATA_A_BF16)
|
||||||
|
#if LOAD_VEC_A == 4
|
||||||
|
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
||||||
|
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
|
||||||
|
buf_a[buf_idx ] = TO_FLOAT_TYPE(data_a[idx].x);
|
||||||
|
buf_a[buf_idx + 1] = TO_FLOAT_TYPE(data_a[idx].y);
|
||||||
|
buf_a[buf_idx + 2] = TO_FLOAT_TYPE(data_a[idx].z);
|
||||||
|
buf_a[buf_idx + 3] = TO_FLOAT_TYPE(data_a[idx].w);
|
||||||
|
#else
|
||||||
|
if (ir * BM + loadc_a + l < p.M && block + loadr_a < end_k) {
|
||||||
|
buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = TO_FLOAT_TYPE(data_a[pos_a + (loadc_a + l) * p.stride_a + loadr_a]);
|
||||||
|
} else {
|
||||||
|
buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = TO_FLOAT_TYPE(uint16_t(0));
|
||||||
|
}
|
||||||
|
#endif
|
||||||
#elif defined(DATA_A_Q4_0)
|
#elif defined(DATA_A_Q4_0)
|
||||||
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
||||||
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 4 * loadr_a;
|
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 4 * loadr_a;
|
||||||
|
@ -695,13 +718,13 @@ void main() {
|
||||||
const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
|
const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
|
||||||
#endif
|
#endif
|
||||||
const uint buf_idx = (loadc_b + l) * SHMEM_STRIDE + loadr_b * LOAD_VEC_B;
|
const uint buf_idx = (loadc_b + l) * SHMEM_STRIDE + loadr_b * LOAD_VEC_B;
|
||||||
buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx].x);
|
buf_b[buf_idx + 0] = TO_FLOAT_TYPE(data_b[idx].x);
|
||||||
buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx].y);
|
buf_b[buf_idx + 1] = TO_FLOAT_TYPE(data_b[idx].y);
|
||||||
buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx].z);
|
buf_b[buf_idx + 2] = TO_FLOAT_TYPE(data_b[idx].z);
|
||||||
buf_b[buf_idx + 3] = FLOAT_TYPE(data_b[idx].w);
|
buf_b[buf_idx + 3] = TO_FLOAT_TYPE(data_b[idx].w);
|
||||||
#elif !MUL_MAT_ID
|
#elif !MUL_MAT_ID
|
||||||
if (ic * BN + loadc_b + l < p.N && block + loadr_b < end_k) {
|
if (ic * BN + loadc_b + l < p.N && block + loadr_b < end_k) {
|
||||||
buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(data_b[pos_b + (loadc_b + l) * p.stride_b + loadr_b]);
|
buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = TO_FLOAT_TYPE(data_b[pos_b + (loadc_b + l) * p.stride_b + loadr_b]);
|
||||||
} else {
|
} else {
|
||||||
buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f);
|
buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f);
|
||||||
}
|
}
|
||||||
|
@ -709,7 +732,7 @@ void main() {
|
||||||
const uint row_i = ic * BN + loadc_b + l;
|
const uint row_i = ic * BN + loadc_b + l;
|
||||||
if (row_i < _ne1) {
|
if (row_i < _ne1) {
|
||||||
const u16vec2 row_idx = row_ids[row_i];
|
const u16vec2 row_idx = row_ids[row_i];
|
||||||
buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]);
|
buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = TO_FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]);
|
||||||
} else {
|
} else {
|
||||||
buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f);
|
buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f);
|
||||||
}
|
}
|
||||||
|
|
|
@ -14,6 +14,9 @@
|
||||||
#extension GL_EXT_buffer_reference : enable
|
#extension GL_EXT_buffer_reference : enable
|
||||||
#extension GL_KHR_shader_subgroup_ballot : enable
|
#extension GL_KHR_shader_subgroup_ballot : enable
|
||||||
#extension GL_KHR_shader_subgroup_vote : enable
|
#extension GL_KHR_shader_subgroup_vote : enable
|
||||||
|
#ifdef DATA_A_BF16
|
||||||
|
#extension GL_EXT_bfloat16 : enable
|
||||||
|
#endif
|
||||||
|
|
||||||
#include "types.comp"
|
#include "types.comp"
|
||||||
|
|
||||||
|
@ -80,6 +83,12 @@ layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
|
||||||
#define store_scales(a)
|
#define store_scales(a)
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#if defined(DATA_A_BF16)
|
||||||
|
#define MAT_TYPE bfloat16_t
|
||||||
|
#else
|
||||||
|
#define MAT_TYPE FLOAT_TYPE
|
||||||
|
#endif
|
||||||
|
|
||||||
#ifdef MUL_MAT_ID
|
#ifdef MUL_MAT_ID
|
||||||
layout (binding = 3) readonly buffer IDS {int data_ids[];};
|
layout (binding = 3) readonly buffer IDS {int data_ids[];};
|
||||||
|
|
||||||
|
@ -271,8 +280,8 @@ void main() {
|
||||||
|
|
||||||
// Manually partial unroll
|
// Manually partial unroll
|
||||||
[[unroll]] for (uint j = 0; j < unroll_count; ++j) {
|
[[unroll]] for (uint j = 0; j < unroll_count; ++j) {
|
||||||
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
|
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
|
||||||
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
|
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
|
||||||
|
|
||||||
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
|
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
|
||||||
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose);
|
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose);
|
||||||
|
@ -286,8 +295,8 @@ void main() {
|
||||||
store_scales(tid);
|
store_scales(tid);
|
||||||
}
|
}
|
||||||
while (block_k < end_k) {
|
while (block_k < end_k) {
|
||||||
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
|
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
|
||||||
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
|
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
|
||||||
|
|
||||||
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
|
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
|
||||||
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose);
|
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose);
|
||||||
|
@ -310,8 +319,8 @@ void main() {
|
||||||
|
|
||||||
// Manually partial unroll
|
// Manually partial unroll
|
||||||
[[unroll]] for (uint j = 0; j < unroll_count; ++j) {
|
[[unroll]] for (uint j = 0; j < unroll_count; ++j) {
|
||||||
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
|
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
|
||||||
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
|
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
|
||||||
|
|
||||||
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
|
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
|
||||||
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose);
|
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose);
|
||||||
|
@ -325,8 +334,8 @@ void main() {
|
||||||
store_scales(tid);
|
store_scales(tid);
|
||||||
}
|
}
|
||||||
while (block_k < end_k) {
|
while (block_k < end_k) {
|
||||||
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
|
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
|
||||||
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
|
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
|
||||||
|
|
||||||
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
|
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
|
||||||
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose);
|
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose);
|
||||||
|
@ -350,8 +359,8 @@ void main() {
|
||||||
|
|
||||||
// Manually partial unroll
|
// Manually partial unroll
|
||||||
[[unroll]] for (uint j = 0; j < unroll_count; ++j) {
|
[[unroll]] for (uint j = 0; j < unroll_count; ++j) {
|
||||||
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
|
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
|
||||||
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
|
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
|
||||||
|
|
||||||
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
|
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
|
||||||
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose);
|
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose);
|
||||||
|
@ -365,8 +374,8 @@ void main() {
|
||||||
store_scales(tid);
|
store_scales(tid);
|
||||||
}
|
}
|
||||||
while (block_k < end_k) {
|
while (block_k < end_k) {
|
||||||
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
|
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
|
||||||
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
|
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
|
||||||
|
|
||||||
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
|
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
|
||||||
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose);
|
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose);
|
||||||
|
@ -405,8 +414,8 @@ void main() {
|
||||||
fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false);
|
fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
|
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
|
||||||
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
|
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
|
||||||
|
|
||||||
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
|
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
|
||||||
#ifdef MUL_MAT_ID
|
#ifdef MUL_MAT_ID
|
||||||
|
|
|
@ -0,0 +1,7 @@
|
||||||
|
#version 460
|
||||||
|
|
||||||
|
#extension GL_EXT_bfloat16 : require
|
||||||
|
|
||||||
|
void main()
|
||||||
|
{
|
||||||
|
}
|
|
@ -33,6 +33,19 @@
|
||||||
#endif
|
#endif
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#if defined(DATA_A_BF16)
|
||||||
|
#define QUANT_K 1
|
||||||
|
#define QUANT_R 1
|
||||||
|
|
||||||
|
#if !defined(LOAD_VEC_A) || LOAD_VEC_A == 1
|
||||||
|
#define A_TYPE uint16_t
|
||||||
|
#elif LOAD_VEC_A == 4
|
||||||
|
#define A_TYPE u16vec4
|
||||||
|
#elif LOAD_VEC_A == 8
|
||||||
|
#error unsupported
|
||||||
|
#endif
|
||||||
|
#endif
|
||||||
|
|
||||||
#define QUANT_K_Q4_0 32
|
#define QUANT_K_Q4_0 32
|
||||||
#define QUANT_R_Q4_0 2
|
#define QUANT_R_Q4_0 2
|
||||||
|
|
||||||
|
@ -1343,4 +1356,18 @@ void init_iq_shmem(uvec3 wgsize)
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
// returns the bfloat value in the low 16b.
|
||||||
|
// See ggml_compute_fp32_to_bf16
|
||||||
|
uint32_t fp32_to_bf16(float f)
|
||||||
|
{
|
||||||
|
uint32_t u = floatBitsToUint(f);
|
||||||
|
u = (u + (0x7fff + ((u >> 16) & 1))) >> 16;
|
||||||
|
return u;
|
||||||
|
}
|
||||||
|
|
||||||
|
float bf16_to_fp32(uint32_t u)
|
||||||
|
{
|
||||||
|
return uintBitsToFloat(u << 16);
|
||||||
|
}
|
||||||
|
|
||||||
#endif // !defined(GGML_TYPES_COMP)
|
#endif // !defined(GGML_TYPES_COMP)
|
||||||
|
|
|
@ -75,7 +75,8 @@ const std::vector<std::string> type_names = {
|
||||||
"iq3_xxs",
|
"iq3_xxs",
|
||||||
"iq3_s",
|
"iq3_s",
|
||||||
"iq4_xs",
|
"iq4_xs",
|
||||||
"iq4_nl"
|
"iq4_nl",
|
||||||
|
"bf16",
|
||||||
};
|
};
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
@ -310,7 +311,6 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
|
||||||
std::string aligned_b_type_f16 = coopmat2 ? "float16_t" : fp16 ? "f16mat2x4" : "f16vec4";
|
std::string aligned_b_type_f16 = coopmat2 ? "float16_t" : fp16 ? "f16mat2x4" : "f16vec4";
|
||||||
|
|
||||||
std::map<std::string, std::string> base_dict = {
|
std::map<std::string, std::string> base_dict = {
|
||||||
{"FLOAT_TYPE", (coopmat2 || fp16) ? "float16_t" : "float"},
|
|
||||||
{"FLOAT_TYPE_VEC2", (coopmat2 || fp16) ? "f16vec2" : "vec2"},
|
{"FLOAT_TYPE_VEC2", (coopmat2 || fp16) ? "f16vec2" : "vec2"},
|
||||||
};
|
};
|
||||||
std::string shader_name = "matmul";
|
std::string shader_name = "matmul";
|
||||||
|
@ -332,12 +332,45 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
|
||||||
|
|
||||||
const std::string source_name = coopmat2 ? "mul_mm_cm2.comp" : "mul_mm.comp";
|
const std::string source_name = coopmat2 ? "mul_mm_cm2.comp" : "mul_mm.comp";
|
||||||
|
|
||||||
// Shaders with f16 B_TYPE
|
auto const &FLOAT_TYPE = [&](const std::string &t) -> std::string {
|
||||||
string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc);
|
if (t == "bf16") {
|
||||||
string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
// scalar path promotes to float
|
||||||
|
if (!coopmat && !coopmat2) {
|
||||||
|
return "float";
|
||||||
|
}
|
||||||
|
return "bfloat16_t";
|
||||||
|
}
|
||||||
|
if (coopmat2 || fp16) {
|
||||||
|
return "float16_t";
|
||||||
|
}
|
||||||
|
return "float";
|
||||||
|
};
|
||||||
|
|
||||||
string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
// Shaders with f16 B_TYPE
|
||||||
string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
|
string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc);
|
||||||
|
string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
||||||
|
|
||||||
|
string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
||||||
|
string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
|
||||||
|
|
||||||
|
// bf16
|
||||||
|
{
|
||||||
|
std::string load_vec_a_unaligned = "1";
|
||||||
|
// For aligned matmul loads
|
||||||
|
std::string load_vec_a = coopmat2 ? "1" : "4";
|
||||||
|
|
||||||
|
// scalar path promotes to float
|
||||||
|
std::string to_float_type = (coopmat || coopmat2) ? "uintBitsToBFloat16EXT" : "bf16_to_fp32";
|
||||||
|
|
||||||
|
// If bfloat16 is not supported, then only compile the scalar (promote to fp32) shader
|
||||||
|
#if !defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
|
||||||
|
if (!(coopmat || coopmat2))
|
||||||
|
#endif
|
||||||
|
{
|
||||||
|
string_to_spv(shader_name + "_bf16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("bf16")}, {"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
||||||
|
string_to_spv(shader_name + "_bf16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("bf16")}, {"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "uint16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for (const auto& tname : type_names) {
|
for (const auto& tname : type_names) {
|
||||||
std::string load_vec_quant = "2";
|
std::string load_vec_quant = "2";
|
||||||
|
@ -346,26 +379,30 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
|
||||||
else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq4_nl"))
|
else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq4_nl"))
|
||||||
load_vec_quant = "4";
|
load_vec_quant = "4";
|
||||||
|
|
||||||
|
if (tname == "bf16") {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
|
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
|
||||||
// For unaligned, load one at a time for f32/f16, or two at a time for quants
|
// For unaligned, load one at a time for f32/f16, or two at a time for quants
|
||||||
std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16") ? "1" : load_vec_quant;
|
std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? "1" : load_vec_quant;
|
||||||
// For aligned matmul loads
|
// For aligned matmul loads
|
||||||
std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16") ? load_vec : load_vec_quant;
|
std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? load_vec : load_vec_quant;
|
||||||
|
|
||||||
// don't generate f32 variants for coopmat2
|
// don't generate f32 variants for coopmat2
|
||||||
if (!coopmat2) {
|
if (!coopmat2) {
|
||||||
string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
|
string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
|
||||||
string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (tname != "f16" && tname != "f32") {
|
if (tname != "f16" && tname != "f32") {
|
||||||
string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
|
string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
|
||||||
string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
||||||
}
|
}
|
||||||
|
|
||||||
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||||
if (!coopmat && !coopmat2 && !matmul_id && (tname == "q4_0" || tname == "q4_1" || tname == "q5_0" || tname == "q5_1" || tname == "q8_0")) {
|
if (!coopmat && !coopmat2 && !matmul_id && (tname == "q4_0" || tname == "q4_1" || tname == "q5_0" || tname == "q5_1" || tname == "q8_0")) {
|
||||||
string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc);
|
string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc);
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
@ -407,6 +444,7 @@ void process_shaders() {
|
||||||
if (tname == "f32") {
|
if (tname == "f32") {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
if (tname == "bf16") continue;
|
||||||
|
|
||||||
if (tname == "f16") {
|
if (tname == "f16") {
|
||||||
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
|
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
|
||||||
|
@ -431,12 +469,12 @@ void process_shaders() {
|
||||||
string_to_spv("mul_mat_vec_id_" + tname + "_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}));
|
string_to_spv("mul_mat_vec_id_" + tname + "_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}));
|
||||||
|
|
||||||
// Dequant shaders
|
// Dequant shaders
|
||||||
if (tname != "f16") {
|
if (tname != "f16" && tname != "bf16") {
|
||||||
string_to_spv("dequant_" + tname, "dequant_" + tname + ".comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float16_t"}}));
|
string_to_spv("dequant_" + tname, "dequant_" + tname + ".comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float16_t"}}));
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!string_ends_with(tname, "_k")) {
|
if (!string_ends_with(tname, "_k")) {
|
||||||
shader = (tname == "f32" || tname == "f16") ? "get_rows.comp" : "get_rows_quant.comp";
|
shader = (tname == "f32" || tname == "f16" || tname == "bf16") ? "get_rows.comp" : "get_rows_quant.comp";
|
||||||
|
|
||||||
if (tname == "f16") {
|
if (tname == "f16") {
|
||||||
string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}));
|
string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}));
|
||||||
|
@ -461,9 +499,11 @@ void process_shaders() {
|
||||||
string_to_spv("cpy_f32_f32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
string_to_spv("cpy_f32_f32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||||
string_to_spv("cpy_f32_f16", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
|
string_to_spv("cpy_f32_f16", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
|
||||||
string_to_spv("cpy_f16_f16", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
|
string_to_spv("cpy_f16_f16", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
|
||||||
|
string_to_spv("cpy_f32_bf16","copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "uint16_t"}, {"DATA_D_BF16", "1"}});
|
||||||
string_to_spv("contig_cpy_f32_f32", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
string_to_spv("contig_cpy_f32_f32", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||||
string_to_spv("contig_cpy_f32_f16", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
|
string_to_spv("contig_cpy_f32_f16", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
|
||||||
string_to_spv("contig_cpy_f16_f16", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
|
string_to_spv("contig_cpy_f16_f16", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
|
||||||
|
string_to_spv("contig_cpy_f32_bf16","contig_copy.comp",{{"A_TYPE", "float"}, {"D_TYPE", "uint16_t"}, {"DATA_D_BF16", "1"}});
|
||||||
|
|
||||||
for (std::string t : {"q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) {
|
for (std::string t : {"q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) {
|
||||||
string_to_spv("cpy_f32_" + t, "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
string_to_spv("cpy_f32_" + t, "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
||||||
|
|
|
@ -231,6 +231,7 @@ class Keys:
|
||||||
BLOCK_COUNT = "clip.vision.block_count"
|
BLOCK_COUNT = "clip.vision.block_count"
|
||||||
IMAGE_MEAN = "clip.vision.image_mean"
|
IMAGE_MEAN = "clip.vision.image_mean"
|
||||||
IMAGE_STD = "clip.vision.image_std"
|
IMAGE_STD = "clip.vision.image_std"
|
||||||
|
SPATIAL_MERGE_SIZE = "clip.vision.spatial_merge_size"
|
||||||
USE_GELU = "clip.use_gelu"
|
USE_GELU = "clip.use_gelu"
|
||||||
USE_SILU = "clip.use_silu"
|
USE_SILU = "clip.use_silu"
|
||||||
|
|
||||||
|
@ -491,6 +492,7 @@ class MODEL_TENSOR(IntEnum):
|
||||||
V_ENC_FFN_DOWN = auto()
|
V_ENC_FFN_DOWN = auto()
|
||||||
V_PRE_NORM = auto()
|
V_PRE_NORM = auto()
|
||||||
V_POST_NORM = auto()
|
V_POST_NORM = auto()
|
||||||
|
V_MM_INP_NORM = auto()
|
||||||
V_MM_INP_PROJ = auto() # gemma3
|
V_MM_INP_PROJ = auto() # gemma3
|
||||||
V_MM_SOFT_EMB_NORM = auto() # gemma3
|
V_MM_SOFT_EMB_NORM = auto() # gemma3
|
||||||
V_RESMPL_POS_EMBD_K = auto() # minicpmv
|
V_RESMPL_POS_EMBD_K = auto() # minicpmv
|
||||||
|
@ -505,6 +507,7 @@ class MODEL_TENSOR(IntEnum):
|
||||||
V_RESMPL_PROJ = auto() # minicpmv
|
V_RESMPL_PROJ = auto() # minicpmv
|
||||||
V_RESMPL_QUERY = auto() # minicpmv
|
V_RESMPL_QUERY = auto() # minicpmv
|
||||||
V_TOK_EMBD_IMG_BREAK = auto() # pixtral
|
V_TOK_EMBD_IMG_BREAK = auto() # pixtral
|
||||||
|
V_MM_PATCH_MERGER = auto() # mistral small 3.1
|
||||||
|
|
||||||
|
|
||||||
MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
||||||
|
@ -747,6 +750,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
||||||
MODEL_TENSOR.V_PRE_NORM: "v.pre_ln",
|
MODEL_TENSOR.V_PRE_NORM: "v.pre_ln",
|
||||||
MODEL_TENSOR.V_POST_NORM: "v.post_ln",
|
MODEL_TENSOR.V_POST_NORM: "v.post_ln",
|
||||||
MODEL_TENSOR.V_MM_INP_PROJ: "mm.input_projection",
|
MODEL_TENSOR.V_MM_INP_PROJ: "mm.input_projection",
|
||||||
|
MODEL_TENSOR.V_MM_INP_NORM: "mm.input_norm",
|
||||||
MODEL_TENSOR.V_MM_SOFT_EMB_NORM: "mm.soft_emb_norm",
|
MODEL_TENSOR.V_MM_SOFT_EMB_NORM: "mm.soft_emb_norm",
|
||||||
MODEL_TENSOR.V_RESMPL_POS_EMBD_K: "resampler.pos_embd_k",
|
MODEL_TENSOR.V_RESMPL_POS_EMBD_K: "resampler.pos_embd_k",
|
||||||
MODEL_TENSOR.V_RESMPL_ATTN_Q: "resampler.attn.q",
|
MODEL_TENSOR.V_RESMPL_ATTN_Q: "resampler.attn.q",
|
||||||
|
@ -760,6 +764,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
||||||
MODEL_TENSOR.V_RESMPL_PROJ: "resampler.proj",
|
MODEL_TENSOR.V_RESMPL_PROJ: "resampler.proj",
|
||||||
MODEL_TENSOR.V_RESMPL_QUERY: "resampler.query",
|
MODEL_TENSOR.V_RESMPL_QUERY: "resampler.query",
|
||||||
MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK: "v.token_embd.img_break", # pixtral
|
MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK: "v.token_embd.img_break", # pixtral
|
||||||
|
MODEL_TENSOR.V_MM_PATCH_MERGER: "mm.patch_merger", # mistral small 3.1
|
||||||
}
|
}
|
||||||
|
|
||||||
MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||||
|
@ -783,6 +788,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||||
MODEL_TENSOR.V_PRE_NORM,
|
MODEL_TENSOR.V_PRE_NORM,
|
||||||
MODEL_TENSOR.V_POST_NORM,
|
MODEL_TENSOR.V_POST_NORM,
|
||||||
MODEL_TENSOR.V_MM_INP_PROJ,
|
MODEL_TENSOR.V_MM_INP_PROJ,
|
||||||
|
MODEL_TENSOR.V_MM_INP_NORM,
|
||||||
MODEL_TENSOR.V_MM_SOFT_EMB_NORM,
|
MODEL_TENSOR.V_MM_SOFT_EMB_NORM,
|
||||||
MODEL_TENSOR.V_RESMPL_POS_EMBD_K,
|
MODEL_TENSOR.V_RESMPL_POS_EMBD_K,
|
||||||
MODEL_TENSOR.V_RESMPL_ATTN_Q,
|
MODEL_TENSOR.V_RESMPL_ATTN_Q,
|
||||||
|
@ -796,6 +802,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||||
MODEL_TENSOR.V_RESMPL_PROJ,
|
MODEL_TENSOR.V_RESMPL_PROJ,
|
||||||
MODEL_TENSOR.V_RESMPL_QUERY,
|
MODEL_TENSOR.V_RESMPL_QUERY,
|
||||||
MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK,
|
MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK,
|
||||||
|
MODEL_TENSOR.V_MM_PATCH_MERGER,
|
||||||
],
|
],
|
||||||
MODEL_ARCH.LLAMA: [
|
MODEL_ARCH.LLAMA: [
|
||||||
MODEL_TENSOR.TOKEN_EMBD,
|
MODEL_TENSOR.TOKEN_EMBD,
|
||||||
|
|
|
@ -972,6 +972,9 @@ class GGUFWriter:
|
||||||
def add_vision_image_std(self, values: Sequence[float]) -> None:
|
def add_vision_image_std(self, values: Sequence[float]) -> None:
|
||||||
self.add_array(Keys.ClipVision.IMAGE_STD, values)
|
self.add_array(Keys.ClipVision.IMAGE_STD, values)
|
||||||
|
|
||||||
|
def add_vision_spatial_merge_size(self, value: int) -> None:
|
||||||
|
self.add_uint32(Keys.ClipVision.SPATIAL_MERGE_SIZE, value)
|
||||||
|
|
||||||
def add_vision_use_gelu(self, value: bool) -> None:
|
def add_vision_use_gelu(self, value: bool) -> None:
|
||||||
self.add_bool(Keys.ClipVision.USE_GELU, value)
|
self.add_bool(Keys.ClipVision.USE_GELU, value)
|
||||||
|
|
||||||
|
|
|
@ -1001,6 +1001,10 @@ class TensorNameMap:
|
||||||
"multi_modal_projector.mm_input_projection",
|
"multi_modal_projector.mm_input_projection",
|
||||||
),
|
),
|
||||||
|
|
||||||
|
MODEL_TENSOR.V_MM_INP_NORM: (
|
||||||
|
"multi_modal_projector.norm",
|
||||||
|
),
|
||||||
|
|
||||||
MODEL_TENSOR.V_MM_SOFT_EMB_NORM: (
|
MODEL_TENSOR.V_MM_SOFT_EMB_NORM: (
|
||||||
"multi_modal_projector.mm_soft_emb_norm",
|
"multi_modal_projector.mm_soft_emb_norm",
|
||||||
),
|
),
|
||||||
|
@ -1052,6 +1056,10 @@ class TensorNameMap:
|
||||||
MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK: (
|
MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK: (
|
||||||
"v.token_embd.img_break", # for pixtral, this is a generated vector
|
"v.token_embd.img_break", # for pixtral, this is a generated vector
|
||||||
),
|
),
|
||||||
|
|
||||||
|
MODEL_TENSOR.V_MM_PATCH_MERGER: (
|
||||||
|
"multi_modal_projector.patch_merger.merging_layer", # mistral small 3.1
|
||||||
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
# architecture-specific block mappings
|
# architecture-specific block mappings
|
||||||
|
|
|
@ -5185,7 +5185,7 @@ def show_gui():
|
||||||
ctk.CTkButton(tabs , text = "Update", fg_color="#9900cc", hover_color="#aa11dd", command = display_updates, width=90, height = 35 ).grid(row=1,column=0, stick="sw", padx= 5, pady=5)
|
ctk.CTkButton(tabs , text = "Update", fg_color="#9900cc", hover_color="#aa11dd", command = display_updates, width=90, height = 35 ).grid(row=1,column=0, stick="sw", padx= 5, pady=5)
|
||||||
ctk.CTkButton(tabs , text = "Save", fg_color="#084a66", hover_color="#085a88", command = save_config_gui, width=60, height = 35 ).grid(row=1,column=1, stick="sw", padx= 5, pady=5)
|
ctk.CTkButton(tabs , text = "Save", fg_color="#084a66", hover_color="#085a88", command = save_config_gui, width=60, height = 35 ).grid(row=1,column=1, stick="sw", padx= 5, pady=5)
|
||||||
ctk.CTkButton(tabs , text = "Load", fg_color="#084a66", hover_color="#085a88", command = load_config_gui, width=60, height = 35 ).grid(row=1,column=1, stick="sw", padx= 70, pady=5)
|
ctk.CTkButton(tabs , text = "Load", fg_color="#084a66", hover_color="#085a88", command = load_config_gui, width=60, height = 35 ).grid(row=1,column=1, stick="sw", padx= 70, pady=5)
|
||||||
ctk.CTkButton(tabs , text = "Help", fg_color="#992222", hover_color="#bb3333", command = display_help, width=50, height = 35 ).grid(row=1,column=1, stick="sw", padx= 135, pady=5)
|
ctk.CTkButton(tabs , text = "Help (Find Models)", fg_color="#992222", hover_color="#bb3333", command = display_help, width=100, height = 35 ).grid(row=1,column=1, stick="sw", padx= 135, pady=5)
|
||||||
|
|
||||||
# start a thread that tries to get actual gpu names and layer counts
|
# start a thread that tries to get actual gpu names and layer counts
|
||||||
gpuinfo_thread = threading.Thread(target=auto_set_backend_gui)
|
gpuinfo_thread = threading.Thread(target=auto_set_backend_gui)
|
||||||
|
|
|
@ -454,7 +454,7 @@ int32_t llm_chat_apply_template(
|
||||||
ss << "<|" << role << "|>" << "\n" << message->content;
|
ss << "<|" << role << "|>" << "\n" << message->content;
|
||||||
}
|
}
|
||||||
if (add_ass) {
|
if (add_ass) {
|
||||||
ss << "<|assistant|>";
|
ss << "<|assistant|>\n";
|
||||||
}
|
}
|
||||||
} else if (tmpl == LLM_CHAT_TEMPLATE_MINICPM) {
|
} else if (tmpl == LLM_CHAT_TEMPLATE_MINICPM) {
|
||||||
// MiniCPM-3B-OpenHermes-2.5-v2-GGUF
|
// MiniCPM-3B-OpenHermes-2.5-v2-GGUF
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue