diff --git a/common/arg.cpp b/common/arg.cpp index 411d3fe7a..3edde43fb 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -26,6 +26,7 @@ #include #include #include +#include #include #include #include @@ -2377,20 +2378,35 @@ common_params_context common_params_parser_init(common_params & params, llama_ex } throw std::invalid_argument("unknown buffer type"); } - // FIXME: this leaks memory - params.tensor_buft_overrides.push_back({strdup(tensor_name.c_str()), buft_list.at(buffer_type)}); + // keep strings alive and avoid leaking memory by storing them in a static vector + static std::list buft_overrides; + buft_overrides.push_back(tensor_name); + params.tensor_buft_overrides.push_back({buft_overrides.back().c_str(), buft_list.at(buffer_type)}); } } )); add_opt(common_arg( - {"--cpu-moe"}, - "use CPU for Mixture of Experts (MoE) weights", + {"--cpu-moe", "-cmoe"}, + "keep all Mixture of Experts (MoE) weights in the CPU", [](common_params & params) { - params.tensor_buft_overrides.push_back({"\\.ffn_up_exps\\.weight$", ggml_backend_cpu_buffer_type()}); - params.tensor_buft_overrides.push_back({"\\.ffn_down_exps\\.weight$", ggml_backend_cpu_buffer_type()}); - params.tensor_buft_overrides.push_back({"\\.ffn_gate_exps\\.weight$", ggml_backend_cpu_buffer_type()}); + params.tensor_buft_overrides.push_back({"\\.ffn_(up|down|gate)_exps", ggml_backend_cpu_buffer_type()}); } ).set_env("LLAMA_ARG_CPU_MOE")); + add_opt(common_arg( + {"--n-cpu-moe", "-ncmoe"}, "N", + "keep the Mixture of Experts (MoE) weights of the first N layers in the CPU", + [](common_params & params, int value) { + if (value < 0) { + throw std::invalid_argument("invalid value"); + } + for (int i = 0; i < value; ++i) { + // keep strings alive and avoid leaking memory by storing them in a static vector + static std::list buft_overrides; + buft_overrides.push_back(string_format("blk\\.%d\\.ffn_(up|down|gate)_exps", i)); + params.tensor_buft_overrides.push_back({buft_overrides.back().c_str(), ggml_backend_cpu_buffer_type()}); + } + } + ).set_env("LLAMA_ARG_N_CPU_MOE")); add_opt(common_arg( {"-ngl", "--gpu-layers", "--n-gpu-layers"}, "N", "number of layers to store in VRAM", @@ -2651,10 +2667,10 @@ common_params_context common_params_parser_init(common_params & params, llama_ex ).set_examples({LLAMA_EXAMPLE_IMATRIX})); add_opt(common_arg( {"--output-format"}, "{gguf,dat}", - string_format("output format for imatrix file (default: %s)", params.imat_dat ? "dat" : "gguf"), + string_format("output format for imatrix file (default: %s)", params.imat_dat > 0 ? "dat" : "gguf"), [](common_params & params, const std::string & value) { - /**/ if (value == "gguf") { params.imat_dat = false; } - else if (value == "dat") { params.imat_dat = true; } + /**/ if (value == "gguf") { params.imat_dat = -1; } + else if (value == "dat") { params.imat_dat = 1; } else { throw std::invalid_argument("invalid output format"); } } ).set_examples({LLAMA_EXAMPLE_IMATRIX})); diff --git a/common/common.h b/common/common.h index 6f32eadad..cabc106a4 100644 --- a/common/common.h +++ b/common/common.h @@ -435,7 +435,7 @@ struct common_params { int32_t n_out_freq = 10; // output the imatrix every n_out_freq iterations int32_t n_save_freq = 0; // save the imatrix every n_save_freq iterations int32_t i_chunk = 0; // start processing from this chunk - bool imat_dat = false; // whether the legacy imatrix.dat format should be output + int8_t imat_dat = 0; // whether the legacy imatrix.dat format should be output (gguf <= 0 < dat) bool process_output = false; // collect data for the output tensor bool compute_ppl = true; // whether to compute perplexity diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 9303a0476..a215f4ed7 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -678,6 +678,9 @@ class TextModel(ModelBase): if chkhsh == "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2": # ref: https://huggingface.co/THUDM/glm-4-9b-hf res = "glm4" + if chkhsh == "9ca2dd618e8afaf09731a7cf6e2105b373ba6a1821559f258b272fe83e6eb902": + # ref: https://huggingface.co/zai-org/GLM-4.5-Air + res = "glm4" if chkhsh == "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35": # ref: https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0 res = "minerva-7b" @@ -6696,6 +6699,139 @@ class Glm4Model(TextModel): return super().modify_tensors(data_torch, name, bid) +@ModelBase.register("Glm4MoeForCausalLM") +class Glm4MoeModel(TextModel): + model_arch = gguf.MODEL_ARCH.GLM4_MOE + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # GLM4_MOE has num_hidden_layers + 1 actual layers (including NextN layer) + self.block_count = self.hparams["num_hidden_layers"] + self.hparams.get("num_nextn_predict_layers", 0) + self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count) + + def set_vocab(self): + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained(self.dir_model) + special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True) + tokens, toktypes, tokpre = self.get_vocab_base() + self.gguf_writer.add_tokenizer_model("gpt2") + self.gguf_writer.add_tokenizer_pre(tokpre) + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_types(toktypes) + + # Special tokens + # Note: Using <|endoftext|> (151329) for eot causes endless generation + special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["[gMASK]"]) # 151331 + special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"]) # 151336 + special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"]) # 151329 + special_vocab._set_special_token("eom", tokenizer.get_added_vocab()["<|observation|>"]) # 151338 + + # Patch broken chat template + if isinstance(special_vocab.chat_template, str) and "visible_text(m.content).endswith" in special_vocab.chat_template: + special_vocab.chat_template = special_vocab.chat_template.replace( + """{{ visible_text(m.content) }}\n{{- '/nothink' if (enable_thinking is defined and not enable_thinking and not visible_text(m.content).endswith("/nothink")) else '' -}}""", + """{% set content = visible_text(m.content) %}{{ content }}\n{{- '/nothink' if (enable_thinking is defined and not enable_thinking and not content.endswith("/nothink")) else '' -}}""") + + special_vocab.add_to_gguf(self.gguf_writer) + + def set_gguf_parameters(self): + super().set_gguf_parameters() + if (rope_dim := self.hparams.get("head_dim")) is None: + rope_dim = ( + self.hparams["hidden_size"] // self.hparams["num_attention_heads"] + ) + self.gguf_writer.add_rope_dimension_count( + int(rope_dim * self.hparams.get("partial_rotary_factor", 0.5)) + ) + + # MoE parameters - Use only routed expert count (shared experts handled separately) + if (n_routed_experts := self.hparams.get("n_routed_experts")) is not None: + self.gguf_writer.add_expert_count(n_routed_experts) + if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None: + self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size) + if (n_shared_experts := self.hparams.get("n_shared_experts")) is not None: + self.gguf_writer.add_expert_shared_count(n_shared_experts) + if (first_k_dense_replace := self.hparams.get("first_k_dense_replace")) is not None: + self.gguf_writer.add_leading_dense_block_count(first_k_dense_replace) + + # Expert gating function (sigmoid for GLM4_MOE) + self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID) + + # Routed scaling factor + if (routed_scaling_factor := self.hparams.get("routed_scaling_factor")) is not None: + self.gguf_writer.add_expert_weights_scale(routed_scaling_factor) + + # Normalise topk probabilities + if (norm_topk_prob := self.hparams.get("norm_topk_prob")) is not None: + self.gguf_writer.add_expert_weights_norm(norm_topk_prob) + + # NextN/MTP prediction layers + if (num_nextn_predict_layers := self.hparams.get("num_nextn_predict_layers")) is not None: + self.gguf_writer.add_nextn_predict_layers(num_nextn_predict_layers) + + _experts: list[dict[str, Tensor]] | None = None + + def modify_tensors( + self, data_torch: Tensor, name: str, bid: int | None + ) -> Iterable[tuple[str, Tensor]]: + if name.startswith("model.visual."): # ignore visual part + return [] + elif name.startswith("model.language_model."): + name = name.replace("language_model.", "") # for multimodal variants + + # Handle main token embedding (but not layer-specific NextN embeddings) + if name == "model.embed_tokens.weight" and ".layers." not in name: + return [(self.map_tensor_name("token_embd.weight"), data_torch)] + + # Handle routed experts + if name.find("mlp.experts") != -1: + n_experts = self.hparams["n_routed_experts"] + assert bid is not None + + if self._experts is None: + self._experts = [{} for _ in range(self.block_count)] + + self._experts[bid][name] = data_torch + + if len(self._experts[bid]) >= n_experts * 3: + tensors: list[tuple[str, Tensor]] = [] + + # merge the experts into a single 3d tensor + for w_name in ["down_proj", "gate_proj", "up_proj"]: + datas: list[Tensor] = [] + + for xid in range(n_experts): + ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight" + datas.append(self._experts[bid][ename]) + del self._experts[bid][ename] + + data_torch = torch.stack(datas, dim=0) + + merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight" + + new_name = self.map_tensor_name(merged_name) + tensors.append((new_name, data_torch)) + return tensors + else: + return [] + + if name.endswith("e_score_correction_bias"): + name = name.replace("e_score_correction_bias", "e_score_correction.bias") + + new_name = self.map_tensor_name(name) + + return [(new_name, data_torch)] + + def prepare_tensors(self): + super().prepare_tensors() + if self._experts is not None: + # flatten `list[dict[str, Tensor]]` into `list[str]` + experts = [k for d in self._experts for k in d.keys()] + if len(experts) > 0: + raise ValueError(f"Unprocessed experts: {experts}") + + @ModelBase.register("GlmForCausalLM", "ChatGLMModel", "ChatGLMForConditionalGeneration") class ChatGLMModel(TextModel): model_arch = gguf.MODEL_ARCH.CHATGLM diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py index 226805f1e..575e05e19 100755 --- a/convert_hf_to_gguf_update.py +++ b/convert_hf_to_gguf_update.py @@ -147,6 +147,7 @@ pre_computed_hashes = [ {"name": "chatglm-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-chat", "chkhsh": "b6e8e1518dc4305be2fe39c313ed643381c4da5db34a98f6a04c093f8afbe99b"}, {"name": "chatglm-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-chat", "chkhsh": "81d72c7348a9f0ebe86f23298d37debe0a5e71149e29bd283904c02262b27516"}, {"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-hf", "chkhsh": "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2"}, + {"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/zai-org/GLM-4.5-Air", "chkhsh": "9ca2dd618e8afaf09731a7cf6e2105b373ba6a1821559f258b272fe83e6eb902"}, {"name": "minerva-7b", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0", "chkhsh": "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35"}, {"name": "hunyuan", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Hunyuan-A13B-Instruct", "chkhsh": "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664"}, {"name": "hunyuan-dense", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Hunyuan-4B-Instruct", "chkhsh": "bba3b3366b646dbdded5dbc42d59598b849371afc42f7beafa914afaa5b70aa6"}, diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp index 7814c1895..295cb91fb 100644 --- a/ggml/src/ggml-backend-reg.cpp +++ b/ggml/src/ggml-backend-reg.cpp @@ -499,6 +499,9 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent, std::vector search_paths; if (user_search_path == nullptr) { +#ifdef GGML_BACKEND_DIR + search_paths.push_back(fs::u8path(GGML_BACKEND_DIR)); +#endif // default search paths: executable directory, current directory search_paths.push_back(get_executable_path()); search_paths.push_back(fs::current_path()); diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index c5abc6934..91411d9c0 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -1,34 +1,41 @@ +/* + WebGPU backend implementation. + Note: Use ClangFormat to format this file. +*/ + #include "ggml-webgpu.h" +#include "ggml-backend-impl.h" +#include "ggml-impl.h" +#include "ggml-wgsl-shaders.hpp" + #include -#include "ggml-impl.h" -#include "ggml-backend-impl.h" - -#include "ggml-wgsl-shaders.hpp" - +#include #include #include #include +#include #include #ifdef GGML_WEBGPU_DEBUG -#define WEBGPU_LOG_DEBUG(msg) std::cout << msg << std::endl +# define WEBGPU_LOG_DEBUG(msg) std::cout << msg << std::endl #else -#define WEBGPU_LOG_DEBUG(msg) ((void) 0) -#endif // GGML_WEBGPU_DEBUG +# define WEBGPU_LOG_DEBUG(msg) ((void) 0) +#endif // GGML_WEBGPU_DEBUG /* Constants */ -#define WEBGPU_MUL_MAT_WG_SIZE 64 -#define WEBGPU_MUL_MAT_PARAMS_SIZE (13 * sizeof(uint32_t)) // M, N, K, batch sizes, broadcasts -#define WEBGPU_CPY_PARAMS_SIZE (15 * sizeof(uint32_t)) // strides and offsets -#define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4 +#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 16 +#define WEBGPU_MUL_MAT_WG_SIZE 64 +#define WEBGPU_NUM_PARAM_BUFS 100 +#define WEBGPU_PARAMS_BUF_SIZE_BYTES 256 +#define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4 /* End Constants */ // This is a "fake" base pointer, since WebGPU buffers do not have pointers to their locations. -static void * const webgpu_ptr_base = (void *)(uintptr_t) 0x1000; // NOLINT +static void * const webgpu_ptr_base = (void *) (uintptr_t) 0x1000; // NOLINT // Always returns the base offset of a tensor, regardless of views. static uint64_t webgpu_tensor_offset(const ggml_tensor * tensor) { @@ -40,100 +47,172 @@ static uint64_t webgpu_tensor_offset(const ggml_tensor * tensor) { /* Struct definitions */ +// Forward reference +static void ggml_webgpu_create_buffer(wgpu::Device & device, + wgpu::Buffer & buffer, + size_t size, + wgpu::BufferUsage usage, + const char * label); + +struct webgpu_param_bufs { + wgpu::Buffer host_buf; + wgpu::Buffer dev_buf; +}; + +// Holds a pool of parameter buffers for WebGPU operations +struct webgpu_param_buf_pool { + std::vector free; + + std::mutex mutex; + + std::condition_variable cv; + + void init(wgpu::Device device) { + for (int i = 0; i < WEBGPU_NUM_PARAM_BUFS; i++) { + wgpu::Buffer host_buf; + wgpu::Buffer dev_buf; + ggml_webgpu_create_buffer(device, + host_buf, + WEBGPU_PARAMS_BUF_SIZE_BYTES, + wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite, + "ggml_webgpu_host_params_buf"); + ggml_webgpu_create_buffer(device, + dev_buf, + WEBGPU_PARAMS_BUF_SIZE_BYTES, + wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform, + "ggml_webgpu_dev_params_buf"); + free.push_back({ host_buf, dev_buf }); + } + } + + webgpu_param_bufs alloc_bufs() { + std::unique_lock lock(mutex); + cv.wait(lock, [this] { return !free.empty(); }); + webgpu_param_bufs bufs = free.back(); + free.pop_back(); + return bufs; + } + + void free_bufs(std::vector bufs) { + std::lock_guard lock(mutex); + free.insert(free.end(), bufs.begin(), bufs.end()); + cv.notify_all(); + } + + void cleanup() { + std::lock_guard lock(mutex); + for (auto & bufs : free) { + bufs.host_buf.Destroy(); + bufs.dev_buf.Destroy(); + } + free.clear(); + } +}; + // All the base objects needed to run operations on a WebGPU device struct webgpu_context_struct { wgpu::Instance instance; - wgpu::Adapter adapter; - wgpu::Device device; - wgpu::Queue queue; - wgpu::Limits limits; - wgpu::SupportedFeatures features; + wgpu::Adapter adapter; + wgpu::Device device; + wgpu::Queue queue; + wgpu::Limits limits; - std::mutex mutex; - bool device_initialized = false; + std::recursive_mutex mutex; + std::mutex get_tensor_mutex; + std::mutex init_mutex; + + bool device_init = false; + + webgpu_param_buf_pool param_buf_pool; - // pipelines and parameter buffers - // TODO: reuse params buffers for different pipelines when possible wgpu::ComputePipeline memset_pipeline; - wgpu::Buffer memset_params_dev_buf; - wgpu::Buffer memset_params_host_buf; wgpu::ComputePipeline mul_mat_pipeline; - wgpu::Buffer mul_mat_params_dev_buf; - wgpu::Buffer mul_mat_params_host_buf; wgpu::ComputePipeline cpy_pipeline; - wgpu::Buffer cpy_params_dev_buf; - wgpu::Buffer cpy_params_host_buf; size_t memset_bytes_per_thread; // Staging buffer for reading data from the GPU wgpu::Buffer get_tensor_staging_buf; + + // Command buffers which need to be submitted + std::vector staged_command_bufs; + + // Parameter buffers associated with the staged command buffers + std::vector staged_param_bufs; }; typedef std::shared_ptr webgpu_context; struct ggml_backend_webgpu_reg_context { webgpu_context webgpu_ctx; - - size_t device_count; - const char * name; + size_t device_count; + const char * name; }; struct ggml_backend_webgpu_device_context { webgpu_context webgpu_ctx; - - std::string device_name; - std::string device_desc; + std::string device_name; + std::string device_desc; }; struct ggml_backend_webgpu_context { webgpu_context webgpu_ctx; - - std::string name; + std::string name; }; struct ggml_backend_webgpu_buffer_context { webgpu_context webgpu_ctx; - - wgpu::Buffer buffer; + wgpu::Buffer buffer; ggml_backend_webgpu_buffer_context(webgpu_context ctx, wgpu::Buffer buf) : - webgpu_ctx(ctx), buffer(buf) { - } + webgpu_ctx(std::move(ctx)), + buffer(std::move(buf)) {} }; /* End struct definitions */ /* WebGPU object initializations */ -static void ggml_webgpu_create_pipeline(wgpu::Device &device, wgpu::ComputePipeline &pipeline, const char * shader_code, const char * label, const std::vector &constants = {}) { +static void ggml_webgpu_create_pipeline(wgpu::Device & device, + wgpu::ComputePipeline & pipeline, + const char * shader_code, + const char * label, + const std::vector & constants = {}) { WEBGPU_LOG_DEBUG("ggml_webgpu_create_pipeline()"); + wgpu::ShaderSourceWGSL shader_source; shader_source.code = shader_code; + wgpu::ShaderModuleDescriptor shader_desc; shader_desc.nextInChain = &shader_source; + wgpu::ShaderModule shader_module = device.CreateShaderModule(&shader_desc); wgpu::ComputePipelineDescriptor pipeline_desc; - pipeline_desc.label = label; - pipeline_desc.compute.module = shader_module; - pipeline_desc.compute.entryPoint = "main"; // Entry point in the WGSL code - pipeline_desc.layout = nullptr; // nullptr means auto layout + pipeline_desc.label = label; + pipeline_desc.compute.module = shader_module; + pipeline_desc.compute.entryPoint = "main"; // Entry point in the WGSL code + pipeline_desc.layout = nullptr; // nullptr means auto layout if (constants.size() > 0) { - pipeline_desc.compute.constants = constants.data(); + pipeline_desc.compute.constants = constants.data(); pipeline_desc.compute.constantCount = constants.size(); } pipeline = device.CreateComputePipeline(&pipeline_desc); } -static void ggml_webgpu_create_buffer(wgpu::Device &device, wgpu::Buffer &buffer, size_t size, wgpu::BufferUsage usage, const char* label) { +static void ggml_webgpu_create_buffer(wgpu::Device & device, + wgpu::Buffer & buffer, + size_t size, + wgpu::BufferUsage usage, + const char * label) { WEBGPU_LOG_DEBUG("ggml_webgpu_create_buffer()"); wgpu::BufferDescriptor buffer_desc; - buffer_desc.size = size; - buffer_desc.usage = usage; - buffer_desc.label = label; + buffer_desc.size = size; + buffer_desc.usage = usage; + buffer_desc.label = label; buffer_desc.mappedAtCreation = false; + // TODO: error handling buffer = device.CreateBuffer(&buffer_desc); } @@ -142,75 +221,133 @@ static void ggml_webgpu_create_buffer(wgpu::Device &device, wgpu::Buffer &buffer /** WebGPU Actions */ -static void ggml_backend_webgpu_map_buffer(webgpu_context ctx, wgpu::Buffer buffer, wgpu::MapMode mode, size_t offset, size_t size) { - ctx->instance.WaitAny(buffer.MapAsync( - mode, offset, size, wgpu::CallbackMode::WaitAnyOnly, - [](wgpu::MapAsyncStatus status, wgpu::StringView message) { - if (status != wgpu::MapAsyncStatus::Success) { - GGML_LOG_ERROR("ggml_webgpu: Failed to map buffer: %s\n", message.data); - } - }), - UINT64_MAX - ); +static void ggml_backend_webgpu_wait_on_submission(webgpu_context & ctx) { + // Wait for the queue to finish processing all commands + ctx->instance.WaitAny(ctx->queue.OnSubmittedWorkDone( + wgpu::CallbackMode::AllowSpontaneous, + [](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) { + if (status != wgpu::QueueWorkDoneStatus::Success) { + GGML_LOG_ERROR("ggml_webgpu: Failed to wait on queue: %s\n", message.data); + } + }), + UINT64_MAX); } -static void ggml_backend_webgpu_buffer_memset(webgpu_context ctx, wgpu::Buffer buf, uint32_t value, size_t offset, size_t size) { - std::lock_guard lock(ctx->mutex); - wgpu::Device device = ctx->device; +static void ggml_backend_webgpu_submit_queue(webgpu_context & ctx) { + std::lock_guard lock(ctx->mutex); + ctx->queue.Submit(ctx->staged_command_bufs.size(), ctx->staged_command_bufs.data()); + ctx->staged_command_bufs.clear(); + std::vector staged_param_bufs = std::move(ctx->staged_param_bufs); + // Free the staged parameter buffers once the submission completes + ctx->queue.OnSubmittedWorkDone( + wgpu::CallbackMode::AllowSpontaneous, + [ctx, staged_param_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) { + if (status != wgpu::QueueWorkDoneStatus::Success) { + GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", message.data); + } + // Free the staged parameter buffers + ctx->param_buf_pool.free_bufs(staged_param_bufs); + }); +} - // map the host parameters buffer - ggml_backend_webgpu_map_buffer(ctx, ctx->memset_params_host_buf, wgpu::MapMode::Write, 0, ctx->memset_params_host_buf.GetSize()); - uint32_t * params = (uint32_t *) ctx->memset_params_host_buf.GetMappedRange(); +static void ggml_backend_webgpu_map_buffer(webgpu_context & ctx, + wgpu::Buffer & buffer, + wgpu::MapMode mode, + size_t offset, + size_t size) { + ctx->instance.WaitAny(buffer.MapAsync(mode, + offset, + size, + wgpu::CallbackMode::AllowSpontaneous, + [](wgpu::MapAsyncStatus status, wgpu::StringView message) { + if (status != wgpu::MapAsyncStatus::Success) { + GGML_LOG_ERROR("ggml_webgpu: Failed to map buffer: %s\n", + message.data); + } + }), + UINT64_MAX); +} - params[0] = (uint32_t)offset; - params[1] = (uint32_t)size; - params[2] = value; - ctx->memset_params_host_buf.Unmap(); +static void ggml_backend_webgpu_build_and_enqueue(webgpu_context & ctx, + wgpu::ComputePipeline & pipeline, + std::vector params, + std::vector bind_group_entries, + uint32_t wg_x, + bool submit_imm = false) { + webgpu_param_bufs params_bufs = ctx->param_buf_pool.alloc_bufs(); - wgpu::BindGroupEntry entries[2]; - entries[0].binding = 0; // binding for the buffer to memset - entries[0].buffer = buf; - entries[0].offset = 0; - entries[0].size = buf.GetSize(); - entries[1].binding = 1; // binding for the parameters - entries[1].buffer = ctx->memset_params_dev_buf; - entries[1].offset = 0; - entries[1].size = ctx->memset_params_dev_buf.GetSize(); + ggml_backend_webgpu_map_buffer(ctx, params_bufs.host_buf, wgpu::MapMode::Write, 0, params_bufs.host_buf.GetSize()); + uint32_t * _params = (uint32_t *) params_bufs.host_buf.GetMappedRange(); + for (size_t i = 0; i < params.size(); i++) { + _params[i] = params[i]; + }; + + params_bufs.host_buf.Unmap(); + + uint32_t params_bufs_binding_num = bind_group_entries.size(); + bind_group_entries.push_back({ .binding = params_bufs_binding_num, + .buffer = params_bufs.dev_buf, + .offset = 0, + .size = params_bufs.dev_buf.GetSize() }); wgpu::BindGroupDescriptor bind_group_desc; - bind_group_desc.layout = ctx->memset_pipeline.GetBindGroupLayout(0); - bind_group_desc.entryCount = 2; - bind_group_desc.label = "ggml_memset"; - bind_group_desc.entries = entries; - wgpu::BindGroup bind_group = device.CreateBindGroup(&bind_group_desc); + bind_group_desc.layout = pipeline.GetBindGroupLayout(0); + bind_group_desc.entryCount = bind_group_entries.size(); + bind_group_desc.entries = bind_group_entries.data(); + wgpu::BindGroup bind_group = ctx->device.CreateBindGroup(&bind_group_desc); - wgpu::CommandEncoder encoder = device.CreateCommandEncoder(); - encoder.CopyBufferToBuffer( - ctx->memset_params_host_buf, 0, - ctx->memset_params_dev_buf, 0, - ctx->memset_params_dev_buf.GetSize() - ); + wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder(); + encoder.CopyBufferToBuffer(params_bufs.host_buf, 0, params_bufs.dev_buf, 0, params_bufs.dev_buf.GetSize()); wgpu::ComputePassEncoder pass = encoder.BeginComputePass(); - pass.SetPipeline(ctx->memset_pipeline); + pass.SetPipeline(pipeline); pass.SetBindGroup(0, bind_group); - size_t bytes_per_wg = ctx->limits.maxComputeWorkgroupSizeX * ctx->memset_bytes_per_thread; - pass.DispatchWorkgroups(((size + 3) + bytes_per_wg - 1) / bytes_per_wg, 1, 1); + pass.DispatchWorkgroups(wg_x, 1, 1); pass.End(); wgpu::CommandBuffer commands = encoder.Finish(); - - ctx->queue.Submit(1, &commands); + if (submit_imm) { + // Submit immediately + ctx->queue.Submit(1, &commands); + ctx->queue.OnSubmittedWorkDone(wgpu::CallbackMode::AllowSpontaneous, + [ctx, params_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) { + if (status != wgpu::QueueWorkDoneStatus::Success) { + GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", + message.data); + } + ctx->param_buf_pool.free_bufs({ params_bufs }); + }); + } else { + // Lock the context mutex when pushing to the staging vectors. + std::lock_guard lock(ctx->mutex); + // Enqueue commands and only submit if we have enough staged commands + ctx->staged_command_bufs.push_back(commands); + ctx->staged_param_bufs.push_back(params_bufs); + if (ctx->staged_command_bufs.size() == WEBGPU_COMMAND_SUBMIT_BATCH_SIZE) { + ggml_backend_webgpu_submit_queue(ctx); + } + } } -static void ggml_backend_webgpu_wait_on_submission(webgpu_context ctx) { - // Wait for the queue to finish processing all commands - ctx->instance.WaitAny(ctx->queue.OnSubmittedWorkDone(wgpu::CallbackMode::WaitAnyOnly, - [](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) { - if (status != wgpu::QueueWorkDoneStatus::Success) { - GGML_LOG_ERROR("ggml_webgpu: Failed to wait on queue: %s\n", message.data); - } - }), - UINT64_MAX - ); +static void ggml_backend_webgpu_buffer_memset(webgpu_context & ctx, + wgpu::Buffer & buf, + uint32_t value, + size_t offset, + size_t size) { + std::vector params = { (uint32_t) offset, (uint32_t) size, value }; + std::vector entries = { + { .binding = 0, .buffer = buf, .offset = 0, .size = buf.GetSize() } + }; + size_t bytes_per_wg = ctx->limits.maxComputeWorkgroupSizeX * ctx->memset_bytes_per_thread; + uint32_t wg_x = ((size + 3) + bytes_per_wg - 1) / bytes_per_wg; + ggml_backend_webgpu_build_and_enqueue(ctx, ctx->memset_pipeline, params, entries, wg_x, true); +} + +static size_t ggml_backend_webgpu_tensor_offset(const ggml_tensor * tensor) { + return webgpu_tensor_offset(tensor) + tensor->view_offs; +} + +static wgpu::Buffer ggml_backend_webgpu_tensor_buf(const ggml_tensor * tensor) { + ggml_backend_webgpu_buffer_context * ctx = (ggml_backend_webgpu_buffer_context *) tensor->buffer->context; + return ctx->buffer; } /** End WebGPU Actions */ @@ -218,218 +355,146 @@ static void ggml_backend_webgpu_wait_on_submission(webgpu_context ctx) { /** GGML Backend Interface */ static const char * ggml_backend_webgpu_name(ggml_backend_t backend) { - ggml_backend_webgpu_context * ctx = (ggml_backend_webgpu_context *)backend->context; + ggml_backend_webgpu_context * ctx = (ggml_backend_webgpu_context *) backend->context; return ctx->name.c_str(); } static void ggml_backend_webgpu_free(ggml_backend_t backend) { - ggml_backend_webgpu_context * ctx = (ggml_backend_webgpu_context *)backend->context; + ggml_backend_webgpu_context * ctx = (ggml_backend_webgpu_context *) backend->context; WEBGPU_LOG_DEBUG("ggml_backend_webgpu_free(" << ctx->name << ")"); // TODO: cleanup GGML_UNUSED(ctx); } +static void ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { + size_t src_offset = ggml_backend_webgpu_tensor_offset(src); + // assumes power of 2 offset alignment + size_t src_misalignment = src_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1); + // align to minimum offset alignment + src_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1); + size_t dst_offset = ggml_backend_webgpu_tensor_offset(dst); + size_t dst_misalignment = dst_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1); + dst_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1); + uint32_t ne = (uint32_t) ggml_nelements(dst); + std::vector params = { ne, + (uint32_t) (src_misalignment / ggml_type_size(src->type)), + (uint32_t) (dst_misalignment / ggml_type_size(dst->type)), + // Convert byte-strides to element-strides + (uint32_t) (src->nb[0] / ggml_type_size(src->type)), + (uint32_t) (src->nb[1] / ggml_type_size(src->type)), + (uint32_t) (src->nb[2] / ggml_type_size(src->type)), + (uint32_t) (src->nb[3] / ggml_type_size(src->type)), + (uint32_t) (dst->nb[0] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)), + // Logical shape — same for both tensors even if permuted + (uint32_t) src->ne[0], + (uint32_t) src->ne[1], + (uint32_t) src->ne[2], + (uint32_t) src->ne[3] }; + + std::vector entries = { + { .binding = 0, + .buffer = ggml_backend_webgpu_tensor_buf(src), + .offset = src_offset, + .size = (ggml_nbytes(src) + src_misalignment + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) & + ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1) }, + { .binding = 1, + .buffer = ggml_backend_webgpu_tensor_buf(dst), + .offset = dst_offset, + .size = (ggml_nbytes(dst) + dst_misalignment + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) & + ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1) } + }; + + size_t max_wg_size = ctx->limits.maxComputeWorkgroupSizeX; + uint32_t wg_x = (ne + max_wg_size - 1) / max_wg_size; + ggml_backend_webgpu_build_and_enqueue(ctx, ctx->cpy_pipeline, params, entries, wg_x); +} + +static void ggml_webgpu_mul_mat(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) { + std::vector params = { + (uint32_t) dst->ne[1], // number of rows in result (M) + (uint32_t) dst->ne[0], // number of columns in result (N) + (uint32_t) src0->ne[0], // number of columns in src0/src1 (K) + (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), // stride (elements) of src0 in dimension 1 + (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), // stride (elements) of src1 in dimension 1 + (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), // stride (elements) of src0 in dimension 2 + (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), // stride (elements) of src1 in dimension 2 + (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), // stride (elements) of src0 in dimension 3 + (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)), // stride (elements) of src1 in dimension 3 + (uint32_t) src0->ne[2], // batch size in dimension 2 + (uint32_t) src0->ne[3], // batch size in dimension 3 + (uint32_t) (src1->ne[2] / src0->ne[2]), // broadcast in dimension 2 + (uint32_t) (src1->ne[3] / src0->ne[3]) // broadcast in dimension 3 + }; + + std::vector entries = { + { .binding = 0, + .buffer = ggml_backend_webgpu_tensor_buf(src0), + .offset = ggml_backend_webgpu_tensor_offset(src0), + .size = ggml_nbytes(src0) }, + { .binding = 1, + .buffer = ggml_backend_webgpu_tensor_buf(src1), + .offset = ggml_backend_webgpu_tensor_offset(src1), + .size = ggml_nbytes(src1) }, + { .binding = 2, + .buffer = ggml_backend_webgpu_tensor_buf(dst), + .offset = ggml_backend_webgpu_tensor_offset(dst), + .size = ggml_nbytes(dst) } + }; + + uint32_t wg_x = + (dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3] + WEBGPU_MUL_MAT_WG_SIZE - 1) / WEBGPU_MUL_MAT_WG_SIZE; + ggml_backend_webgpu_build_and_enqueue(ctx, ctx->mul_mat_pipeline, params, entries, wg_x); +} + // Returns true if node has enqueued work into the queue, false otherwise -static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node){ +static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) { if (ggml_is_empty(node)) { return false; } - WEBGPU_LOG_DEBUG("ggml_webgpu_encode_node(" << node << ", " << ggml_op_name(node->op) << ")"); + ggml_tensor * src0 = node->src[0]; + ggml_tensor * src1 = node->src[1]; switch (node->op) { - // no-ops + // no-ops case GGML_OP_NONE: case GGML_OP_VIEW: case GGML_OP_PERMUTE: return false; - - case GGML_OP_CPY: { - std::lock_guard lock(ctx->mutex); - const ggml_tensor * src = node->src[0]; - ggml_backend_webgpu_buffer_context * src_ctx = (ggml_backend_webgpu_buffer_context *) src->buffer->context; - size_t src_offset = webgpu_tensor_offset(src) + src->view_offs; - // assumes power of 2 offset alignment - size_t src_misalignment = src_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1); - // align to minimum offset alignment - src_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1); - ggml_backend_webgpu_buffer_context * dst_ctx = (ggml_backend_webgpu_buffer_context *) node->buffer->context; - size_t dst_offset = webgpu_tensor_offset(node) + node->view_offs; - size_t dst_misalignment = dst_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1); - dst_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1); - - wgpu::Device device = ctx->device; - ggml_backend_webgpu_map_buffer(ctx, ctx->cpy_params_host_buf, - wgpu::MapMode::Write, 0, ctx->cpy_params_host_buf.GetSize()); - uint32_t * params = (uint32_t *) ctx->cpy_params_host_buf.GetMappedRange(); - uint32_t ne = (uint32_t)ggml_nelements(node); - params[0] = ne; - params[1] = src_misalignment/ggml_type_size(src->type); - params[2] = dst_misalignment/ggml_type_size(node->type); - - // Convert byte-strides to element-strides - params[3] = (uint32_t)src->nb[0]/ggml_type_size(src->type); - params[4] = (uint32_t)src->nb[1]/ggml_type_size(src->type); - params[5] = (uint32_t)src->nb[2]/ggml_type_size(src->type); - params[6] = (uint32_t)src->nb[3]/ggml_type_size(src->type); - params[7] = (uint32_t)node->nb[0]/ggml_type_size(node->type); - params[8] = (uint32_t)node->nb[1]/ggml_type_size(node->type); - params[9] = (uint32_t)node->nb[2]/ggml_type_size(node->type); - params[10] = (uint32_t)node->nb[3]/ggml_type_size(node->type); - // Logical shape — same for both tensors even if permuted - params[11] = (uint32_t)(src->ne[0]); - params[12] = (uint32_t)(src->ne[1]); - params[13] = (uint32_t)(src->ne[2]); - params[14] = (uint32_t)(src->ne[3]); - - ctx->cpy_params_host_buf.Unmap(); - - wgpu::BindGroupEntry entries[3]; - entries[0].binding = 0; - entries[0].buffer = src_ctx->buffer; - entries[0].offset = src_offset; - entries[0].size = (ggml_nbytes(src) + src_misalignment + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) & ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1); - - entries[1].binding = 1; - entries[1].buffer = dst_ctx->buffer; - entries[1].offset = dst_offset; - entries[1].size = (ggml_nbytes(node) + dst_misalignment + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) & ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1); - - entries[2].binding = 2; - entries[2].buffer = ctx->cpy_params_dev_buf; - entries[2].offset = 0; - entries[2].size = ctx->cpy_params_dev_buf.GetSize(); - - wgpu::BindGroupDescriptor bind_group_desc; - bind_group_desc.layout = ctx->cpy_pipeline.GetBindGroupLayout(0); - bind_group_desc.label = "ggml_op_cpy"; - bind_group_desc.entryCount = 3; - bind_group_desc.entries = entries; - wgpu::BindGroup bind_group = device.CreateBindGroup(&bind_group_desc); - - wgpu::CommandEncoder encoder = device.CreateCommandEncoder(); - encoder.CopyBufferToBuffer( - ctx->cpy_params_host_buf, 0, - ctx->cpy_params_dev_buf, 0, - ctx->cpy_params_dev_buf.GetSize() - ); - wgpu::ComputePassEncoder pass = encoder.BeginComputePass(); - pass.SetPipeline(ctx->cpy_pipeline); - pass.SetBindGroup(0, bind_group); - size_t max_wg_size = ctx->limits.maxComputeWorkgroupSizeX; - pass.DispatchWorkgroups((ne + max_wg_size - 1) / max_wg_size); - pass.End(); - wgpu::CommandBuffer commands = encoder.Finish(); - - // TODO, don't submit here, batch submissions - ctx->queue.Submit(1, &commands); - // TODO, don't wait on submission here - ggml_backend_webgpu_wait_on_submission(ctx); - return true; - } - + case GGML_OP_CPY: + { + ggml_webgpu_cpy(ctx, src0, node); + break; + } case GGML_OP_MUL_MAT: - { - const ggml_tensor * src0 = node->src[0]; - ggml_backend_webgpu_buffer_context * src0_ctx = (ggml_backend_webgpu_buffer_context *) src0->buffer->context; - size_t src0_offset = webgpu_tensor_offset(src0) + src0->view_offs; - const ggml_tensor * src1 = node->src[1]; - ggml_backend_webgpu_buffer_context * src1_ctx = (ggml_backend_webgpu_buffer_context *) src1->buffer->context; - size_t src1_offset = webgpu_tensor_offset(src1) + src1->view_offs; - ggml_backend_webgpu_buffer_context * dst_ctx = (ggml_backend_webgpu_buffer_context *) node->buffer->context; - - size_t dst_offset = webgpu_tensor_offset(node) + node->view_offs; - - wgpu::Device device = ctx->device; - - // map the host parameters buffer - ggml_backend_webgpu_map_buffer(ctx, ctx->mul_mat_params_host_buf, - wgpu::MapMode::Write, 0, ctx->mul_mat_params_host_buf.GetSize()); - uint32_t * params = (uint32_t *) ctx->mul_mat_params_host_buf.GetMappedRange(); - - params[0] = (uint32_t)node->ne[1]; // number of rows in result (M) - params[1] = (uint32_t)node->ne[0]; // number of columns in result (N) - params[2] = (uint32_t)src0->ne[0]; // number of columns in src0/src1 (K) - - params[3] = (uint32_t)src0->nb[1]/ggml_type_size(src0->type); // stride (elements) of src0 in dimension 1 - params[4] = (uint32_t)src1->nb[1]/ggml_type_size(src1->type); // stride (elements) of src1 in dimension 1 - params[5] = (uint32_t)src0->nb[2]/ggml_type_size(src0->type); // stride (elements) of src0 in dimension 2 - params[6] = (uint32_t)src1->nb[2]/ggml_type_size(src1->type); // stride (elements) of src1 in dimension 2 - params[7] = (uint32_t)src0->nb[3]/ggml_type_size(src0->type); // stride (elements) of src0 in dimension 3 - params[8] = (uint32_t)src1->nb[3]/ggml_type_size(src1->type); // stride (elements) of src1 in dimension 3 - - params[9] = (uint32_t)src0->ne[2]; // batch size in dimension 2 - params[10] = (uint32_t)src0->ne[3]; // batch size in dimension 3 - params[11] = (uint32_t)(src1->ne[2]/src0->ne[2]); // broadcast in dimension 2 - params[12] = (uint32_t)(src1->ne[3]/src0->ne[3]); // broadcast in dimension 3 - - ctx->mul_mat_params_host_buf.Unmap(); - - wgpu::BindGroupEntry entries[4]; - entries[0].binding = 0; - entries[0].buffer = src0_ctx->buffer; - entries[0].offset = src0_offset; - entries[0].size = ggml_nbytes(src0); - - entries[1].binding = 1; - entries[1].buffer = src1_ctx->buffer; - entries[1].offset = src1_offset; - entries[1].size = ggml_nbytes(src1); - - entries[2].binding = 2; - entries[2].buffer = dst_ctx->buffer; - entries[2].offset = dst_offset; - entries[2].size = ggml_nbytes(node); - - entries[3].binding = 3; - entries[3].buffer = ctx->mul_mat_params_dev_buf; - entries[3].offset = 0; - entries[3].size = ctx->mul_mat_params_dev_buf.GetSize(); - - wgpu::BindGroupDescriptor bind_group_desc; - bind_group_desc.layout = ctx->mul_mat_pipeline.GetBindGroupLayout(0); - bind_group_desc.entryCount = 4; - bind_group_desc.label = "ggml_op_mul_mat"; - bind_group_desc.entries = entries; - wgpu::BindGroup bind_group = device.CreateBindGroup(&bind_group_desc); - - wgpu::CommandEncoder encoder = device.CreateCommandEncoder(); - encoder.CopyBufferToBuffer( - ctx->mul_mat_params_host_buf, 0, - ctx->mul_mat_params_dev_buf, 0, - ctx->mul_mat_params_dev_buf.GetSize() - ); - wgpu::ComputePassEncoder pass = encoder.BeginComputePass(); - pass.SetPipeline(ctx->mul_mat_pipeline); - pass.SetBindGroup(0, bind_group); - pass.DispatchWorkgroups((node->ne[0] * node->ne[1] * node->ne[2] * node->ne[3] + WEBGPU_MUL_MAT_WG_SIZE - 1) / WEBGPU_MUL_MAT_WG_SIZE); - pass.End(); - wgpu::CommandBuffer commands = encoder.Finish(); - - // TODO, don't submit here, batch submissions - ctx->queue.Submit(1, &commands); - // TODO, don't wait on submission here - ggml_backend_webgpu_wait_on_submission(ctx); - return true; - } - + { + ggml_webgpu_mul_mat(ctx, src0, src1, node); + break; + } default: return false; } + return true; } static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) { WEBGPU_LOG_DEBUG("ggml_backend_webgpu_graph_compute(" << cgraph->n_nodes << " nodes)"); ggml_backend_webgpu_context * backend_ctx = static_cast(backend->context); - webgpu_context ctx = backend_ctx->webgpu_ctx; + webgpu_context ctx = backend_ctx->webgpu_ctx; for (int i = 0; i < cgraph->n_nodes; i++) { ggml_webgpu_encode_node(ctx, cgraph->nodes[i]); } + ggml_backend_webgpu_submit_queue(ctx); + ggml_backend_webgpu_wait_on_submission(ctx); + return GGML_STATUS_SUCCESS; } @@ -465,49 +530,69 @@ static void * ggml_backend_webgpu_buffer_get_base(ggml_backend_buffer_t buffer) return webgpu_ptr_base; } -static void ggml_backend_webgpu_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { +static void ggml_backend_webgpu_buffer_memset_tensor(ggml_backend_buffer_t buffer, + ggml_tensor * tensor, + uint8_t value, + size_t offset, + size_t size) { if (size == 0) { WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_memset_tensor: size is zero, nothing to do."); return; } - WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_memset_tensor(" << buffer << ", " << tensor << ", " << value << ", " << offset << ", " << size << ")"); + WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_memset_tensor(" << buffer << ", " << tensor << ", " << value << ", " + << offset << ", " << size << ")"); ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context; + size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset; + // This is a trick to set all bytes of a u32 to the same 1 byte value. - uint32_t val32 = (uint32_t)value * 0x01010101; + uint32_t val32 = (uint32_t) value * 0x01010101; ggml_backend_webgpu_buffer_memset(buf_ctx->webgpu_ctx, buf_ctx->buffer, val32, total_offset, size); } -static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { - WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_set_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")"); - ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context; - webgpu_context webgpu_ctx = buf_ctx->webgpu_ctx; +static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer, + ggml_tensor * tensor, + const void * data, + size_t offset, + size_t size) { + WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_set_tensor(" << buffer << ", " << tensor << ", " << data << ", " + << offset << ", " << size << ")"); + ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context; + webgpu_context webgpu_ctx = buf_ctx->webgpu_ctx; size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset; - webgpu_ctx->queue.WriteBuffer(buf_ctx->buffer, total_offset, data, (size/4)*4); + webgpu_ctx->queue.WriteBuffer(buf_ctx->buffer, total_offset, data, (size / 4) * 4); if (size % 4 != 0) { // If size is not a multiple of 4, we need to memset the remaining bytes size_t remaining_size = size % 4; + // pack the remaining bytes into a uint32_t uint32_t val32 = 0; + for (size_t i = 0; i < remaining_size; i++) { - ((uint8_t *)&val32)[i] = ((const uint8_t *)data)[size - remaining_size + i]; + ((uint8_t *) &val32)[i] = ((const uint8_t *) data)[size - remaining_size + i]; } // memset the remaining bytes - ggml_backend_webgpu_buffer_memset(webgpu_ctx, buf_ctx->buffer, val32, total_offset + (size - remaining_size), remaining_size); + ggml_backend_webgpu_buffer_memset( + webgpu_ctx, buf_ctx->buffer, val32, total_offset + (size - remaining_size), remaining_size); } } -static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { - WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_get_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")"); +static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer, + const ggml_tensor * tensor, + void * data, + size_t offset, + size_t size) { + WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_get_tensor(" << buffer << ", " << tensor << ", " << data << ", " + << offset << ", " << size << ")"); - ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context; - webgpu_context webgpu_ctx = buf_ctx->webgpu_ctx; - wgpu::Device device = webgpu_ctx->device; + ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context; + webgpu_context webgpu_ctx = buf_ctx->webgpu_ctx; + wgpu::Device device = webgpu_ctx->device; size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset; @@ -517,22 +602,25 @@ static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer, final_size = size + (4 - (size % 4)); } - std::lock_guard lock(webgpu_ctx->mutex); + std::lock_guard lock(webgpu_ctx->get_tensor_mutex); - if (webgpu_ctx->get_tensor_staging_buf == nullptr || - webgpu_ctx->get_tensor_staging_buf.GetSize() < final_size) { + if (webgpu_ctx->get_tensor_staging_buf == nullptr || webgpu_ctx->get_tensor_staging_buf.GetSize() < final_size) { // Create a new staging buffer if it doesn't exist or is too small if (webgpu_ctx->get_tensor_staging_buf) { webgpu_ctx->get_tensor_staging_buf.Destroy(); } - ggml_webgpu_create_buffer(device, webgpu_ctx->get_tensor_staging_buf, final_size, - wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "get_tensor_staging_buf"); + ggml_webgpu_create_buffer(device, + webgpu_ctx->get_tensor_staging_buf, + final_size, + wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, + "get_tensor_staging_buf"); } // Copy the data from the buffer to the staging buffer wgpu::CommandEncoder encoder = device.CreateCommandEncoder(); encoder.CopyBufferToBuffer(buf_ctx->buffer, total_offset, webgpu_ctx->get_tensor_staging_buf, 0, final_size); wgpu::CommandBuffer commands = encoder.Finish(); + // Submit the command buffer to the queue webgpu_ctx->queue.Submit(1, &commands); @@ -548,7 +636,6 @@ static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer, static void ggml_backend_webgpu_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_clear(" << buffer << ", " << (uint32_t) value << ")"); - ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context; ggml_backend_webgpu_buffer_memset(buf_ctx->webgpu_ctx, buf_ctx->buffer, value, 0, buffer->size); } @@ -556,13 +643,13 @@ static void ggml_backend_webgpu_buffer_clear(ggml_backend_buffer_t buffer, uint8 static ggml_backend_buffer_i ggml_backend_webgpu_buffer_interface = { /* .free_buffer = */ ggml_backend_webgpu_buffer_free_buffer, /* .get_base = */ ggml_backend_webgpu_buffer_get_base, - /* .init_tensor = */ NULL, // TODO: optional, needed? + /* .init_tensor = */ NULL, // TODO: optional, needed? /* .memset_tensor = */ ggml_backend_webgpu_buffer_memset_tensor, /* .set_tensor = */ ggml_backend_webgpu_buffer_set_tensor, /* .get_tensor = */ ggml_backend_webgpu_buffer_get_tensor, - /* .cpy_tensor = */ NULL, // TODO: optional, implement this + /* .cpy_tensor = */ NULL, // TODO: optional, implement this /* .clear = */ ggml_backend_webgpu_buffer_clear, - /* .reset = */ NULL, // TODO: optional, think it coordinates with .init_tensor + /* .reset = */ NULL, // TODO: optional, think it coordinates with .init_tensor }; /* End GGML Backend Buffer Interface */ @@ -574,13 +661,17 @@ static const char * ggml_backend_webgpu_buffer_type_get_name(ggml_backend_buffer return ctx->device_name.c_str(); } -static ggml_backend_buffer_t ggml_backend_webgpu_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { +static ggml_backend_buffer_t ggml_backend_webgpu_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, + size_t size) { WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_type_alloc_buffer(" << size << ")"); ggml_backend_webgpu_device_context * ctx = static_cast(buft->device->context); wgpu::Buffer buf; - ggml_webgpu_create_buffer(ctx->webgpu_ctx->device, buf, size, - wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst, "allocated_buffer"); + ggml_webgpu_create_buffer(ctx->webgpu_ctx->device, + buf, + size, + wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst, + "allocated_buffer"); ggml_backend_webgpu_buffer_context * buf_ctx = new ggml_backend_webgpu_buffer_context(ctx->webgpu_ctx, buf); @@ -615,8 +706,8 @@ static const char * ggml_backend_webgpu_device_get_description(ggml_backend_dev_ static void ggml_backend_webgpu_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { ggml_backend_webgpu_device_context * ctx = static_cast(dev->context); // TODO: what do we actually want to return here? maxBufferSize might not be the full available memory. - *free = ctx->webgpu_ctx->limits.maxBufferSize; - *total = ctx->webgpu_ctx->limits.maxBufferSize; + *free = ctx->webgpu_ctx->limits.maxBufferSize; + *total = ctx->webgpu_ctx->limits.maxBufferSize; } static enum ggml_backend_dev_type ggml_backend_webgpu_device_get_type(ggml_backend_dev_t dev) { @@ -639,98 +730,93 @@ static void ggml_backend_webgpu_device_get_props(ggml_backend_dev_t dev, struct static ggml_guid_t ggml_backend_webgpu_guid(void) { static const char * guid_str = "__ggml_webgpu :)"; - return reinterpret_cast((void *)guid_str); + return reinterpret_cast((void *) guid_str); } -static void ggml_webgpu_init_memset_pipeline(webgpu_context webgpu_ctx) { +static void ggml_webgpu_init_memset_pipeline(webgpu_context & webgpu_ctx) { // we use the maximum workgroup size for the memset pipeline size_t max_wg_size = webgpu_ctx->limits.maxComputeWorkgroupSizeX; size_t max_threads = max_wg_size * webgpu_ctx->limits.maxComputeWorkgroupsPerDimension; // Size the bytes_per_thread so that the largest buffer size can be handled - webgpu_ctx->memset_bytes_per_thread = (webgpu_ctx->limits.maxStorageBufferBindingSize + max_threads - 1) / max_threads; + webgpu_ctx->memset_bytes_per_thread = + (webgpu_ctx->limits.maxStorageBufferBindingSize + max_threads - 1) / max_threads; std::vector constants(2); - constants[0].key = "wg_size"; + constants[0].key = "wg_size"; constants[0].value = max_wg_size; - constants[1].key = "bytes_per_thread"; + constants[1].key = "bytes_per_thread"; constants[1].value = webgpu_ctx->memset_bytes_per_thread; ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->memset_pipeline, wgsl_memset, "memset", constants); - ggml_webgpu_create_buffer(webgpu_ctx->device, webgpu_ctx->memset_params_dev_buf, - 3 * sizeof(uint32_t), // 3 parameters: buffer size, offset, value - wgpu::BufferUsage::Uniform | wgpu::BufferUsage::CopyDst, "memset_params_dev_buf"); - ggml_webgpu_create_buffer(webgpu_ctx->device, webgpu_ctx->memset_params_host_buf, - 3 * sizeof(uint32_t), wgpu::BufferUsage::MapWrite | wgpu::BufferUsage::CopySrc, "memset_params_host_buf"); } -static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context webgpu_ctx) { +static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) { ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline, wgsl_mul_mat, "mul_mat"); - ggml_webgpu_create_buffer(webgpu_ctx->device, webgpu_ctx->mul_mat_params_dev_buf, WEBGPU_MUL_MAT_PARAMS_SIZE, - wgpu::BufferUsage::Uniform | wgpu::BufferUsage::CopyDst, "mul_mat_params_dev_buf"); - ggml_webgpu_create_buffer(webgpu_ctx->device, webgpu_ctx->mul_mat_params_host_buf, WEBGPU_MUL_MAT_PARAMS_SIZE, - wgpu::BufferUsage::MapWrite | wgpu::BufferUsage::CopySrc, "mul_mat_params_host_buf"); } -static void ggml_webgpu_init_cpy_pipeline(webgpu_context webgpu_ctx) { +static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) { std::vector constants(1); - constants[0].key = "wg_size"; + constants[0].key = "wg_size"; constants[0].value = webgpu_ctx->limits.maxComputeWorkgroupSizeX; - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline, wgsl_cpy, "cpy", constants); - ggml_webgpu_create_buffer(webgpu_ctx->device, webgpu_ctx->cpy_params_dev_buf, WEBGPU_CPY_PARAMS_SIZE, - wgpu::BufferUsage::Uniform | wgpu::BufferUsage::CopyDst, "cpy_params_dev_buf"); - ggml_webgpu_create_buffer(webgpu_ctx->device, webgpu_ctx->cpy_params_host_buf, WEBGPU_CPY_PARAMS_SIZE, - wgpu::BufferUsage::MapWrite | wgpu::BufferUsage::CopySrc, "cpy_params_host_buf"); } -// TODO: Make thread safe if multiple devices are used static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, const char * params) { GGML_UNUSED(params); WEBGPU_LOG_DEBUG("ggml_backend_webgpu_device_init()"); - ggml_backend_webgpu_device_context * dev_ctx = static_cast(dev->context); - webgpu_context webgpu_ctx = dev_ctx->webgpu_ctx; + ggml_backend_webgpu_device_context * dev_ctx = static_cast(dev->context); + webgpu_context webgpu_ctx = dev_ctx->webgpu_ctx; - std::lock_guard lock(webgpu_ctx->mutex); - - if (!webgpu_ctx->device_initialized) { + // Multiple threads may try to initialize the device + std::lock_guard lock(webgpu_ctx->init_mutex); + if (!webgpu_ctx->device_init) { // Initialize device - wgpu::DeviceDescriptor dev_desc; - dev_desc.requiredLimits = &webgpu_ctx->limits; - dev_desc.requiredFeatures = webgpu_ctx->features.features; - dev_desc.requiredFeatureCount = webgpu_ctx->features.featureCount; - dev_desc.SetDeviceLostCallback(wgpu::CallbackMode::AllowSpontaneous, - [](const wgpu::Device& device, wgpu::DeviceLostReason reason, wgpu::StringView message) { + std::vector required_features = { wgpu::FeatureName::ShaderF16, wgpu::FeatureName::ImplicitDeviceSynchronization }; + wgpu::DeviceDescriptor dev_desc; + dev_desc.requiredLimits = &webgpu_ctx->limits; + dev_desc.requiredFeatures = required_features.data(); + dev_desc.requiredFeatureCount = required_features.size(); + dev_desc.SetDeviceLostCallback( + wgpu::CallbackMode::AllowSpontaneous, + [](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) { GGML_UNUSED(device); - GGML_LOG_ERROR("ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast(reason), message.data); - }); + GGML_LOG_ERROR( + "ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast(reason), message.data); + }); dev_desc.SetUncapturedErrorCallback( - [](const wgpu::Device& device, wgpu::ErrorType reason, wgpu::StringView message) { + [](const wgpu::Device & device, wgpu::ErrorType reason, wgpu::StringView message) { GGML_UNUSED(device); - GGML_LOG_ERROR("ggml_webgpu: Device error! Reason: %d, Message: %s\n", static_cast(reason), message.data); - }); - webgpu_ctx->instance.WaitAny(webgpu_ctx->adapter.RequestDevice(&dev_desc, wgpu::CallbackMode::WaitAnyOnly, - [webgpu_ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) { - if (status != wgpu::RequestDeviceStatus::Success) { - GGML_LOG_ERROR("ggml_webgpu: Failed to get a device: %s\n", message.data); - return; - } - webgpu_ctx->device = device; - }), - UINT64_MAX - ); + GGML_LOG_ERROR( + "ggml_webgpu: Device error! Reason: %d, Message: %s\n", static_cast(reason), message.data); + }); + webgpu_ctx->instance.WaitAny( + webgpu_ctx->adapter.RequestDevice( + &dev_desc, + wgpu::CallbackMode::AllowSpontaneous, + [webgpu_ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) { + if (status != wgpu::RequestDeviceStatus::Success) { + GGML_LOG_ERROR("ggml_webgpu: Failed to get a device: %s\n", message.data); + return; + } + webgpu_ctx->device = std::move(device); + }), + UINT64_MAX); GGML_ASSERT(webgpu_ctx->device != nullptr); // Initialize (compute) queue webgpu_ctx->queue = webgpu_ctx->device.GetQueue(); + // Create buffer pool for shader parameters + webgpu_ctx->param_buf_pool.init(webgpu_ctx->device); + ggml_webgpu_init_memset_pipeline(webgpu_ctx); ggml_webgpu_init_mul_mat_pipeline(webgpu_ctx); ggml_webgpu_init_cpy_pipeline(webgpu_ctx); - webgpu_ctx->device_initialized = true; + webgpu_ctx->device_init = true; } static ggml_backend_webgpu_context backend_ctx; - backend_ctx.name = GGML_WEBGPU_NAME + std::string(": ") + dev_ctx->device_name; + backend_ctx.name = GGML_WEBGPU_NAME + std::string(": ") + dev_ctx->device_name; backend_ctx.webgpu_ctx = webgpu_ctx; // See GGML Backend Interface section @@ -748,14 +834,15 @@ static ggml_backend_buffer_type_t ggml_backend_webgpu_device_get_buffer_type(ggm // See GGML Backend Buffer Type Interface section static struct ggml_backend_buffer_type ggml_backend_webgpu_buffer_type = { /* .iface = */ { - /* .get_name = */ ggml_backend_webgpu_buffer_type_get_name, - /* .alloc_buffer = */ ggml_backend_webgpu_buffer_type_alloc_buffer, - /* .get_alignment = */ ggml_backend_webgpu_buffer_type_get_alignment, - /* .get_max_size = */ ggml_backend_webgpu_buffer_type_get_max_size, - /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes - /* .is_host = */ NULL, // defaults to false + /* .get_name = */ ggml_backend_webgpu_buffer_type_get_name, + /* .alloc_buffer = */ ggml_backend_webgpu_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_webgpu_buffer_type_get_alignment, + /* .get_max_size = */ ggml_backend_webgpu_buffer_type_get_max_size, + /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes + /* .is_host = */ NULL, // defaults to false }, - /* .device = */ dev, + /* .device = */ + dev, /* .context = */ NULL, }; @@ -764,7 +851,7 @@ static ggml_backend_buffer_type_t ggml_backend_webgpu_device_get_buffer_type(ggm static bool ggml_backend_webgpu_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) { GGML_UNUSED(dev); - return buft->iface.get_name == ggml_backend_webgpu_buffer_type_get_name; + return buft->iface.get_name == ggml_backend_webgpu_buffer_type_get_name; } static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) { @@ -827,30 +914,38 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t webgpu_context ctx = reg_ctx->webgpu_ctx; wgpu::RequestAdapterOptions options = {}; - auto callback = [](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char *message, void *userdata) { - if (status != wgpu::RequestAdapterStatus::Success) { - GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message); - return; - } - *static_cast(userdata) = adapter; - }; - void *userdata = &ctx->adapter; - ctx->instance.WaitAny(ctx->instance.RequestAdapter(&options, wgpu::CallbackMode::WaitAnyOnly, callback, userdata), UINT64_MAX); + auto callback = + [](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message, void * userdata) { + if (status != wgpu::RequestAdapterStatus::Success) { + GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message); + return; + } + *static_cast(userdata) = std::move(adapter); + }; + void * userdata = &ctx->adapter; + ctx->instance.WaitAny( + ctx->instance.RequestAdapter(&options, wgpu::CallbackMode::AllowSpontaneous, callback, userdata), UINT64_MAX); GGML_ASSERT(ctx->adapter != nullptr); ctx->adapter.GetLimits(&ctx->limits); - ctx->adapter.GetFeatures(&ctx->features); wgpu::AdapterInfo info{}; ctx->adapter.GetInfo(&info); static ggml_backend_webgpu_device_context device_ctx; - device_ctx.webgpu_ctx = ctx; + device_ctx.webgpu_ctx = ctx; device_ctx.device_name = GGML_WEBGPU_NAME; device_ctx.device_desc = std::string(info.description.data); - GGML_LOG_INFO("ggml_webgpu: adapter_info: vendor_id: %u | vendor: %s | architecture: %s | device_id: %u | name: %s | device_desc: %s\n", - info.vendorID, info.vendor.data, info.architecture.data, info.deviceID, info.device.data, info.description.data); + GGML_LOG_INFO( + "ggml_webgpu: adapter_info: vendor_id: %u | vendor: %s | architecture: %s | device_id: %u | name: %s | " + "device_desc: %s\n", + info.vendorID, + info.vendor.data, + info.architecture.data, + info.deviceID, + info.device.data, + info.description.data); // See GGML Backend Device Interface section static ggml_backend_device device = { @@ -861,7 +956,6 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t return &device; } - static const struct ggml_backend_reg_i ggml_backend_webgpu_reg_i = { /* .get_name = */ ggml_backend_webgpu_reg_get_name, /* .get_device_count = */ ggml_backend_webgpu_reg_get_device_count, @@ -871,23 +965,21 @@ static const struct ggml_backend_reg_i ggml_backend_webgpu_reg_i = { /* End GGML Backend Registration Interface */ -// TODO: Does this need to be thread safe? Is it only called once? ggml_backend_reg_t ggml_backend_webgpu_reg() { WEBGPU_LOG_DEBUG("ggml_backend_webgpu_reg()"); webgpu_context webgpu_ctx = std::make_shared(); - webgpu_ctx->device_initialized = false; static ggml_backend_webgpu_reg_context ctx; - ctx.webgpu_ctx = webgpu_ctx; - ctx.name = GGML_WEBGPU_NAME; + ctx.webgpu_ctx = webgpu_ctx; + ctx.name = GGML_WEBGPU_NAME; ctx.device_count = 1; - wgpu::InstanceDescriptor instance_descriptor{}; - std::vector instance_features = {wgpu::InstanceFeatureName::TimedWaitAny}; - instance_descriptor.requiredFeatures = instance_features.data(); - instance_descriptor.requiredFeatureCount = instance_features.size(); - webgpu_ctx->instance = wgpu::CreateInstance(&instance_descriptor); + wgpu::InstanceDescriptor instance_descriptor{}; + std::vector instance_features = { wgpu::InstanceFeatureName::TimedWaitAny }; + instance_descriptor.requiredFeatures = instance_features.data(); + instance_descriptor.requiredFeatureCount = instance_features.size(); + webgpu_ctx->instance = wgpu::CreateInstance(&instance_descriptor); GGML_ASSERT(webgpu_ctx->instance != nullptr); static ggml_backend_reg reg = { diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 5707085cb..e2d81dd98 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -105,6 +105,7 @@ class Keys: EXPERT_WEIGHTS_NORM = "{arch}.expert_weights_norm" EXPERT_GATING_FUNC = "{arch}.expert_gating_func" MOE_EVERY_N_LAYERS = "{arch}.moe_every_n_layers" + NEXTN_PREDICT_LAYERS = "{arch}.nextn_predict_layers" POOLING_TYPE = "{arch}.pooling_type" LOGIT_SCALE = "{arch}.logit_scale" DECODER_START_TOKEN_ID = "{arch}.decoder_start_token_id" @@ -357,6 +358,7 @@ class MODEL_ARCH(IntEnum): DEEPSEEK2 = auto() CHATGLM = auto() GLM4 = auto() + GLM4_MOE = auto() BITNET = auto() T5 = auto() T5ENCODER = auto() @@ -614,6 +616,13 @@ class MODEL_TENSOR(IntEnum): A_MMPROJ_FC = auto() A_MM_NORM_PRE = auto() A_MM_NORM_MID = auto() + # nextn/mtp + NEXTN_EH_PROJ = auto() + NEXTN_EMBED_TOKENS = auto() + NEXTN_ENORM = auto() + NEXTN_HNORM = auto() + NEXTN_SHARED_HEAD_HEAD = auto() + NEXTN_SHARED_HEAD_NORM = auto() MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { @@ -678,6 +687,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.DEEPSEEK2: "deepseek2", MODEL_ARCH.CHATGLM: "chatglm", MODEL_ARCH.GLM4: "glm4", + MODEL_ARCH.GLM4_MOE: "glm4moe", MODEL_ARCH.BITNET: "bitnet", MODEL_ARCH.T5: "t5", MODEL_ARCH.T5ENCODER: "t5encoder", @@ -936,6 +946,13 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = { MODEL_TENSOR.A_MMPROJ_FC: "mm.a.fc", MODEL_TENSOR.A_MM_NORM_PRE: "mm.a.norm_pre", MODEL_TENSOR.A_MM_NORM_MID: "mm.a.norm_mid", + # NextN/MTP + MODEL_TENSOR.NEXTN_EH_PROJ: "blk.{bid}.nextn.eh_proj", + MODEL_TENSOR.NEXTN_EMBED_TOKENS: "blk.{bid}.nextn.embed_tokens", + MODEL_TENSOR.NEXTN_ENORM: "blk.{bid}.nextn.enorm", + MODEL_TENSOR.NEXTN_HNORM: "blk.{bid}.nextn.hnorm", + MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD: "blk.{bid}.nextn.shared_head_head", + MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM: "blk.{bid}.nextn.shared_head_norm", } MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { @@ -2124,6 +2141,37 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.ATTN_POST_NORM, MODEL_TENSOR.FFN_POST_NORM, ], + MODEL_ARCH.GLM4_MOE: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_POST_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + MODEL_TENSOR.FFN_GATE_SHEXP, + MODEL_TENSOR.FFN_DOWN_SHEXP, + MODEL_TENSOR.FFN_UP_SHEXP, + MODEL_TENSOR.FFN_EXP_PROBS_B, + # NextN/MTP tensors - preserved but unused + MODEL_TENSOR.NEXTN_EH_PROJ, + MODEL_TENSOR.NEXTN_EMBED_TOKENS, + MODEL_TENSOR.NEXTN_ENORM, + MODEL_TENSOR.NEXTN_HNORM, + MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD, + MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM, + ], MODEL_ARCH.BITNET: [ MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index f4fd64ad8..89249021b 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -753,6 +753,9 @@ class GGUFWriter: def add_moe_every_n_layers(self, value: int) -> None: self.add_uint32(Keys.LLM.MOE_EVERY_N_LAYERS.format(arch=self.arch), value) + def add_nextn_predict_layers(self, count: int) -> None: + self.add_uint32(Keys.LLM.NEXTN_PREDICT_LAYERS.format(arch=self.arch), count) + def add_swin_norm(self, value: bool) -> None: self.add_bool(Keys.LLM.SWIN_NORM.format(arch=self.arch), value) diff --git a/gguf-py/gguf/scripts/gguf_new_metadata.py b/gguf-py/gguf/scripts/gguf_new_metadata.py index 63f230034..2fa5800cf 100755 --- a/gguf-py/gguf/scripts/gguf_new_metadata.py +++ b/gguf-py/gguf/scripts/gguf_new_metadata.py @@ -111,6 +111,7 @@ def main() -> None: parser.add_argument("--general-description", type=str, help="The models general.description", metavar='"Description ..."') parser.add_argument("--chat-template", type=str, help="Chat template string (or JSON string containing templates)", metavar='"{% ... %} ..."') parser.add_argument("--chat-template-config", type=Path, help="Config file containing chat template(s)", metavar='tokenizer_config.json') + parser.add_argument("--chat-template-file", type=Path, help="Jinja file containing chat template", metavar='chat_template.jinja') parser.add_argument("--pre-tokenizer", type=str, help="The models tokenizer.ggml.pre", metavar='"pre tokenizer"') parser.add_argument("--remove-metadata", action="append", type=str, help="Remove metadata (by key name) from output model", metavar='general.url') parser.add_argument("--special-token", action="append", type=str, help="Special token by value", nargs=2, metavar=(' | '.join(token_names.keys()), '""')) @@ -134,12 +135,17 @@ def main() -> None: new_metadata[gguf.Keys.Tokenizer.CHAT_TEMPLATE] = MetadataDetails(gguf.GGUFValueType.STRING, json.loads(args.chat_template) if args.chat_template.startswith('[') else args.chat_template) if args.chat_template_config: - with open(args.chat_template_config, 'r') as fp: + with open(args.chat_template_config, 'r', encoding='utf-8') as fp: config = json.load(fp) template = config.get('chat_template') if template: new_metadata[gguf.Keys.Tokenizer.CHAT_TEMPLATE] = MetadataDetails(gguf.GGUFValueType.STRING, template) + if args.chat_template_file: + with open(args.chat_template_file, 'r', encoding='utf-8') as fp: + template = fp.read() + new_metadata[gguf.Keys.Tokenizer.CHAT_TEMPLATE] = MetadataDetails(gguf.GGUFValueType.STRING, template) + if args.pre_tokenizer: new_metadata[gguf.Keys.Tokenizer.PRE] = MetadataDetails(gguf.GGUFValueType.STRING, args.pre_tokenizer) diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index e6efc93fa..dd4f3d52e 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -1369,6 +1369,31 @@ class TensorNameMap: MODEL_TENSOR.A_MM_NORM_MID: ( "audio.multi_modal_projector.ln_mid", # ultravox ), + + # NextN/MTP tensors for GLM4_MOE + MODEL_TENSOR.NEXTN_EH_PROJ: ( + "model.layers.{bid}.eh_proj", + ), + + MODEL_TENSOR.NEXTN_EMBED_TOKENS: ( + "model.layers.{bid}.embed_tokens", + ), + + MODEL_TENSOR.NEXTN_ENORM: ( + "model.layers.{bid}.enorm", + ), + + MODEL_TENSOR.NEXTN_HNORM: ( + "model.layers.{bid}.hnorm", + ), + + MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD: ( + "model.layers.{bid}.shared_head.head", + ), + + MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM: ( + "model.layers.{bid}.shared_head.norm", + ), } # architecture-specific block mappings diff --git a/klite.embd b/klite.embd index d71922d66..3b8053b66 100644 --- a/klite.embd +++ b/klite.embd @@ -10547,7 +10547,6 @@ Current version indicated by LITEVER below. } function update_oai_model_list(data, dropdown) { - let isOpenrouter = (document.getElementById("customapidropdown").value==3); var lastOption = dropdown.lastElementChild; for (var i = dropdown.options.length - 1; i >= 0; i--) { var option = dropdown.options[i]; @@ -10565,7 +10564,7 @@ Current version indicated by LITEVER below. var el = document.createElement("option"); el.textContent = sortedarr[i]; el.value = sortedarr[i]; - if(isOpenrouter && sortedarr[i]=="mistralai/mistral-7b-instruct") + if(localsettings.prev_custom_endpoint_model && sortedarr[i]==localsettings.prev_custom_endpoint_model) { selidx = i; } diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index ba7bf9598..8d669bddc 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -62,6 +62,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_DEEPSEEK2, "deepseek2" }, { LLM_ARCH_CHATGLM, "chatglm" }, { LLM_ARCH_GLM4, "glm4" }, + { LLM_ARCH_GLM4_MOE, "glm4moe" }, { LLM_ARCH_BITNET, "bitnet" }, { LLM_ARCH_T5, "t5" }, { LLM_ARCH_T5ENCODER, "t5encoder" }, @@ -127,6 +128,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_EXPERT_WEIGHTS_NORM, "%s.expert_weights_norm" }, { LLM_KV_EXPERT_GATING_FUNC, "%s.expert_gating_func" }, { LLM_KV_MOE_EVERY_N_LAYERS, "%s.moe_every_n_layers" }, + { LLM_KV_NEXTN_PREDICT_LAYERS, "%s.nextn_predict_layers" }, { LLM_KV_POOLING_TYPE, "%s.pooling_type" }, { LLM_KV_LOGIT_SCALE, "%s.logit_scale" }, { LLM_KV_DECODER_START_TOKEN_ID, "%s.decoder_start_token_id" }, @@ -1391,6 +1393,40 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, }, }, + { + LLM_ARCH_GLM4_MOE, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, + { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, + { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, + { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" }, + // NextN/MTP tensors - preserved but unused (in final layer, dynamic layer number) + { LLM_TENSOR_NEXTN_EH_PROJ, "blk.%d.nextn.eh_proj" }, + { LLM_TENSOR_NEXTN_EMBED_TOKENS, "blk.%d.nextn.embed_tokens" }, + { LLM_TENSOR_NEXTN_ENORM, "blk.%d.nextn.enorm" }, + { LLM_TENSOR_NEXTN_HNORM, "blk.%d.nextn.hnorm" }, + { LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "blk.%d.nextn.shared_head_head" }, + { LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "blk.%d.nextn.shared_head_norm" }, + }, + }, { LLM_ARCH_BITNET, { @@ -2181,6 +2217,14 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_SHORTCONV_CONV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_CONV}}, {LLM_TENSOR_SHORTCONV_INPROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_SHORTCONV_OUTPROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + // NextN/MTP tensors are currently ignored (reserved for future MTP support) + // These tensors only exist in the last layer(s) and are treated as output tensors + {LLM_TENSOR_NEXTN_EH_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_NEXTN_EMBED_TOKENS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}}, + {LLM_TENSOR_NEXTN_ENORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}}, + {LLM_TENSOR_NEXTN_HNORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, + {LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, }; LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {} diff --git a/src/llama-arch.h b/src/llama-arch.h index 9b8bd65b2..456eb8d8c 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -66,6 +66,7 @@ enum llm_arch { LLM_ARCH_DEEPSEEK2, LLM_ARCH_CHATGLM, LLM_ARCH_GLM4, + LLM_ARCH_GLM4_MOE, LLM_ARCH_BITNET, LLM_ARCH_T5, LLM_ARCH_T5ENCODER, @@ -131,6 +132,7 @@ enum llm_kv { LLM_KV_EXPERT_WEIGHTS_NORM, LLM_KV_EXPERT_GATING_FUNC, LLM_KV_MOE_EVERY_N_LAYERS, + LLM_KV_NEXTN_PREDICT_LAYERS, LLM_KV_POOLING_TYPE, LLM_KV_LOGIT_SCALE, LLM_KV_DECODER_START_TOKEN_ID, @@ -409,6 +411,12 @@ enum llm_tensor { LLM_TENSOR_SHORTCONV_CONV, LLM_TENSOR_SHORTCONV_INPROJ, LLM_TENSOR_SHORTCONV_OUTPROJ, + LLM_TENSOR_NEXTN_EH_PROJ, + LLM_TENSOR_NEXTN_EMBED_TOKENS, + LLM_TENSOR_NEXTN_ENORM, + LLM_TENSOR_NEXTN_HNORM, + LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, + LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, }; enum llm_tensor_layer { diff --git a/src/llama-context.cpp b/src/llama-context.cpp index e1ecdd4cb..7c047f54a 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -786,7 +786,7 @@ int llama_context::encode(const llama_batch & batch_inp) { const auto & hparams = model.hparams; const int64_t n_embd = hparams.n_embd; - const int32_t n_vocab = model.vocab.n_tokens(); + const int64_t n_vocab = model.vocab.n_tokens(); // note: during encode, we always pass the full sequence starting from pos = 0 if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) { @@ -959,7 +959,7 @@ int llama_context::decode(const llama_batch & batch_inp) { const auto & vocab = model.vocab; const auto & hparams = model.hparams; - const int32_t n_vocab = vocab.n_tokens(); + const int64_t n_vocab = vocab.n_tokens(); const int64_t n_embd = hparams.n_embd; // when computing embeddings, all tokens are output @@ -1328,21 +1328,21 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { } void llama_context::output_reorder() { - const uint32_t n_vocab = model.vocab.n_tokens(); + const uint64_t n_vocab = model.vocab.n_tokens(); const uint64_t n_embd = model.hparams.n_embd; - for (uint32_t s = 0; s < output_swaps.size(); ++s) { - const uint32_t i0 = output_swaps[s].i0; - const uint32_t i1 = output_swaps[s].i1; + for (size_t s = 0; s < output_swaps.size(); ++s) { + const uint64_t i0 = output_swaps[s].i0; + const uint64_t i1 = output_swaps[s].i1; if (logits_size > 0) { - for (uint32_t k = 0; k < n_vocab; k++) { + for (uint64_t k = 0; k < n_vocab; k++) { std::swap(logits[i0*n_vocab + k], logits[i1*n_vocab + k]); } } if (embd_size > 0) { - for (uint32_t k = 0; k < n_embd; k++) { + for (uint64_t k = 0; k < n_embd; k++) { std::swap(embd[i0*n_embd + k], embd[i1*n_embd + k]); } } diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 491a26b63..9c15e8324 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -749,8 +749,8 @@ ggml_tensor * llm_graph_context::build_ffn( if (down) { cur = build_lora_mm(down, cur); - if (arch == LLM_ARCH_GLM4) { - // GLM4 seems to have numerical issues with half-precision accumulators + if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) { + // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators ggml_mul_mat_set_prec(cur, GGML_PREC_F32); } } @@ -1391,8 +1391,8 @@ ggml_tensor * llm_graph_context::build_attn( if (wo) { cur = build_lora_mm(wo, cur); - if (arch == LLM_ARCH_GLM4) { - // GLM4 seems to have numerical issues with half-precision accumulators + if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) { + // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators ggml_mul_mat_set_prec(cur, GGML_PREC_F32); } } diff --git a/src/llama-hparams.h b/src/llama-hparams.h index 8b7e2a113..d60035726 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -73,6 +73,7 @@ struct llama_hparams { bool expert_weights_norm = false; uint32_t expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_NONE; uint32_t moe_every_n_layers = 0; + uint32_t nextn_predict_layers = 0; float f_norm_eps; float f_norm_rms_eps; diff --git a/src/llama-kv-cache-unified.cpp b/src/llama-kv-cache-unified.cpp index 2de13f250..caa253f8b 100644 --- a/src/llama-kv-cache-unified.cpp +++ b/src/llama-kv-cache-unified.cpp @@ -39,6 +39,10 @@ llama_kv_cache_unified::llama_kv_cache_unified( if (model.arch == LLM_ARCH_GEMMA3N) { n_layer_cache = 20; } + if (model.arch == LLM_ARCH_GLM4_MOE) { + // GLM-4.5: Only process up to last layer, skip final NextN layer + n_layer_cache = hparams.n_layer - hparams.nextn_predict_layers; + } // create a context for each buffer type std::map ctx_map; diff --git a/src/llama-model-loader.h b/src/llama-model-loader.h index 0f52b011b..c9189f6cb 100644 --- a/src/llama-model-loader.h +++ b/src/llama-model-loader.h @@ -58,8 +58,9 @@ struct llama_model_loader { } }; - static const int TENSOR_NOT_REQUIRED = 1; - static const int TENSOR_DUPLICATED = 2; + static const int TENSOR_NOT_REQUIRED = 1 << 0; + static const int TENSOR_DUPLICATED = 1 << 1; + static const int TENSOR_SKIP = 1 << 2; int n_kv = 0; int n_tensors = 0; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 104179553..97d07698d 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -114,8 +114,10 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_A13B: return "A13B"; case LLM_TYPE_21B_A3B: return "21B.A3B"; case LLM_TYPE_30B_A3B: return "30B.A3B"; + case LLM_TYPE_106B_A12B: return "106B.A12B"; case LLM_TYPE_235B_A22B: return "235B.A22B"; case LLM_TYPE_300B_A47B: return "300B.A47B"; + case LLM_TYPE_355B_A32B: return "355B.A32B"; case LLM_TYPE_E2B: return "E2B"; case LLM_TYPE_E4B: return "E4B"; default: return "?B"; @@ -1439,6 +1441,34 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_GLM4_MOE: + { + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + // MoE parameters + ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert); + ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + + // Expert gating function (GLM-4.5 uses sigmoid) + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); + if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) { + hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID; + } + + // NextN/MTP parameters + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); + + switch (hparams.n_layer) { + case 47: type = LLM_TYPE_106B_A12B; break; // GLM-4.5-Air (46 layers + 1 NextN layer) + case 93: type = LLM_TYPE_355B_A32B; break; // GLM-4.5 (92 layers + 1 NextN layer) + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_BITNET: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -1961,6 +1991,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { const auto TENSOR_DUPLICATED = llama_model_loader::TENSOR_DUPLICATED; const auto TENSOR_NOT_REQUIRED = llama_model_loader::TENSOR_NOT_REQUIRED; + const auto TENSOR_SKIP = llama_model_loader::TENSOR_SKIP; // create tensors for the weights { @@ -2020,7 +2051,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } // skip unused tensors - if (info.op == GGML_OP_NONE) { + if (info.op == GGML_OP_NONE || flags & TENSOR_SKIP) { const size_t nbytes = ggml_nbytes(t_meta); LLAMA_LOG_WARN("model has unused tensor %s (size = %zu bytes) -- ignoring\n", tn.str().c_str(), nbytes); @@ -4523,6 +4554,105 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); } } break; + case LLM_ARCH_GLM4_MOE: + { + const int64_t n_expert = hparams.n_expert; + const int64_t n_expert_used = hparams.n_expert_used; + const int64_t n_expert_shared = hparams.n_expert_shared; + + GGML_ASSERT(hparams.n_expert > 0 && "n_expert must be > 0 for GLM4_MOE MoE layers"); + GGML_ASSERT(hparams.n_expert_used > 0 && "n_expert_used must be > 0 for GLM4_MOE MoE layers"); + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); + } + + // Load ALL tensors including NextN layer to satisfy total tensor count + // but only PROCESS up to last layer (skipping final NextN layer) in forward pass + for (int i = 0; i < n_layer; ++i) { + int flags = 0; + if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { + // skip all tensors in the NextN layers + flags |= TENSOR_SKIP; + } + + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, flags); + + // GLM-style attention with bias terms + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, flags); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, flags); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, flags); + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), { n_embd_head_k * n_head }, flags); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), { n_embd_k_gqa }, flags); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), { n_embd_v_gqa }, flags); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, flags); + + // K/Q norm tensors (optional for GLM-4.5 355B variant) + layer.attn_q_norm = create_tensor( + tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, TENSOR_NOT_REQUIRED | flags); + layer.attn_k_norm = create_tensor( + tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, TENSOR_NOT_REQUIRED | flags); + + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, flags); + + // Check if this layer uses MoE or dense FFN based on n_layer_dense_lead + // GLM 4.5 uses hybrid architecture: layer 0 is dense, layers 1+ are MoE + const bool use_moe = (static_cast(i) >= hparams.n_layer_dense_lead); + + if (use_moe) { + // MoE layers + layer.ffn_gate_inp = + create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, flags); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), { n_expert }, flags); + + // MoE branch + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; + + layer.ffn_gate_exps = create_tensor( + tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, flags); + layer.ffn_down_exps = create_tensor( + tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, flags); + layer.ffn_up_exps = create_tensor( + tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, flags); + + // Shared expert + if (n_expert_shared > 0) { + const int64_t n_ff_shexp = n_ff_exp * n_expert_shared; + layer.ffn_gate_shexp = create_tensor( + tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp }, flags); + layer.ffn_down_shexp = create_tensor( + tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_shexp, n_embd }, flags); + layer.ffn_up_shexp = create_tensor( + tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp }, flags); + } + } else { + // Dense layers (first k layers) - GLM uses separate gate/up projections + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, flags); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, flags); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, flags); + } + + // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers + if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { + layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags); + layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, flags); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags); + layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, flags); + layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, flags); + } + } + } + break; case LLM_ARCH_NEMOTRON: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -13664,6 +13794,169 @@ struct llm_build_glm4 : public llm_graph_context { } }; +struct llm_build_glm4_moe : public llm_graph_context { + llm_build_glm4_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + // Only process up to last layer (skip final NextN layer) + // Final layer tensors are loaded but not processed in forward pass + const int n_transformer_layers = n_layer - hparams.nextn_predict_layers; + for (int il = 0; il < n_transformer_layers; ++il) { + ggml_tensor * inpSA = inpL; + + // Pre-attention norm + cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + } + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + } + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + } + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + // Apply Q/K norm if available (GLM-4.5 355B variant) + if (model.layers[il].attn_q_norm) { + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + } + if (model.layers[il].attn_k_norm) { + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); + cb(Kcur, "Kcur_normed", il); + } + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_transformer_layers - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // Post-attention norm + cur = build_norm(ffn_inp, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "post_attn_norm", il); + + // Check if this is a dense layer (n_layer_dense_lead=1, so layer 0 is dense) + if (static_cast(il) < hparams.n_layer_dense_lead) { + // Dense FFN layer + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } else { + // MoE layer with shared experts + const int64_t n_expert = hparams.n_expert; + const int64_t n_expert_used = hparams.n_expert_used; + + // Process routed experts using existing MoE infrastructure + ggml_tensor * routed_out = build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + model.layers[il].ffn_exp_probs_b, + n_expert, n_expert_used, + LLM_FFN_SILU, hparams.expert_weights_norm, + true, hparams.expert_weights_scale, + (llama_expert_gating_func_type) hparams.expert_gating_func, + il); + cb(routed_out, "ffn_moe_out", il); + + // Process shared expert on original input + ggml_tensor * shared_out = build_ffn(cur, + model.layers[il].ffn_up_shexp, NULL, NULL, + model.layers[il].ffn_gate_shexp, NULL, NULL, + model.layers[il].ffn_down_shexp, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(shared_out, "ffn_shexp_out", il); + + // Final output: routed_output + shared_output + cur = ggml_add(ctx0, routed_out, shared_out); + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + struct llm_build_nemotron : public llm_graph_context { llm_build_nemotron(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; @@ -17977,6 +18270,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_GLM4_MOE: + { + llm = std::make_unique(*this, params); + } break; case LLM_ARCH_BITNET: { llm = std::make_unique(*this, params); @@ -18308,6 +18605,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_HUNYUAN_DENSE: case LLM_ARCH_LFM2: case LLM_ARCH_SMALLTHINKER: + case LLM_ARCH_GLM4_MOE: return LLAMA_ROPE_TYPE_NEOX; case LLM_ARCH_QWEN2VL: diff --git a/src/llama-model.h b/src/llama-model.h index 094e23808..bdb81cecd 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -101,8 +101,10 @@ enum llm_type { LLM_TYPE_A13B, LLM_TYPE_21B_A3B, // Ernie MoE small LLM_TYPE_30B_A3B, + LLM_TYPE_106B_A12B, // GLM-4.5-Air LLM_TYPE_235B_A22B, LLM_TYPE_300B_A47B, // Ernie MoE big + LLM_TYPE_355B_A32B, // GLM-4.5 LLM_TYPE_E2B, LLM_TYPE_E4B, }; @@ -166,6 +168,15 @@ struct llama_layer_shortconv { struct ggml_tensor * out_proj = nullptr; }; +struct llama_layer_nextn { + struct ggml_tensor * eh_proj = nullptr; + struct ggml_tensor * embed_tokens = nullptr; + struct ggml_tensor * enorm = nullptr; + struct ggml_tensor * hnorm = nullptr; + struct ggml_tensor * shared_head_head = nullptr; + struct ggml_tensor * shared_head_norm = nullptr; +}; + struct llama_layer { // normalization struct ggml_tensor * attn_norm = nullptr; @@ -354,6 +365,8 @@ struct llama_layer { struct llama_layer_convnext convnext; struct llama_layer_shortconv shortconv; + + struct llama_layer_nextn nextn; }; struct llama_model { diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index 839931f46..ee7a620f6 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -2430,6 +2430,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { || t.first == "<|fim▁begin|>" // DeepSeek || t.first == "
"
                         || t.first == "▁
"          // CodeLlama
+                        || t.first == "<|code_prefix|>" // GLM-4.5
                         ) {
                     special_fim_pre_id = t.second;
                     if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
@@ -2449,6 +2450,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                         || t.first == "<|fim▁hole|>" // DeepSeek
                         || t.first == ""
                         || t.first == "▁"         // CodeLlama
+                        || t.first == "<|code_suffix|>" // GLM-4.5
                         ) {
                     special_fim_suf_id = t.second;
                     if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
@@ -2468,6 +2470,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                         || t.first == "<|fim▁end|>"  // DeepSeek
                         || t.first == ""
                         || t.first == "▁"         // CodeLlama
+                        || t.first == "<|code_middle|>" // GLM-4.5
                         ) {
                     special_fim_mid_id = t.second;
                     if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
diff --git a/tools/quantize/quantize.cpp b/tools/quantize/quantize.cpp
index 26457680c..6ee32ad6b 100644
--- a/tools/quantize/quantize.cpp
+++ b/tools/quantize/quantize.cpp
@@ -612,7 +612,7 @@ int main(int argc, char ** argv) {
             return 1;
         }
         if (!try_parse_ftype(argv[arg_idx], params.ftype, ftype_str)) {
-            fprintf(stderr, "%s: invalid ftype '%s'\n", __func__, argv[3]);
+            fprintf(stderr, "%s: invalid ftype '%s'\n", __func__, argv[arg_idx]);
             return 1;
         }
         if (ftype_str == "COPY") {