From 8cf6b42d467d05fa7d9776d2bcc69974ecce6900 Mon Sep 17 00:00:00 2001 From: matteo Date: Thu, 23 Oct 2025 11:32:24 +0200 Subject: [PATCH 1/7] server : send partial stop string when <EOG> is reached (#15007) --- tools/server/server.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 8737fba12..85849e160 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -2839,7 +2839,7 @@ struct server_context { slot.generated_text.begin() + pos + stop_pos, slot.generated_text.end()); pos = std::min(slot.n_sent_text, slot.generated_text.size()); - } else if (slot.has_next_token) { + } else if (slot.has_next_token && !llama_vocab_is_eog(vocab, result.tok) ) { stop_pos = slot.find_stopping_strings(str_test, token_str.size(), false); send_text = stop_pos == std::string::npos; } From 061f0eff02dc9a82f7bd850db3bd70b8a0b5e87a Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Thu, 23 Oct 2025 19:14:06 +0800 Subject: [PATCH 2/7] ggml-cuda: use passed ops instead of hardcoded ops (#16712) --- ggml/src/ggml-cuda/ggml-cuda.cu | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 6e7c5aedb..f5a6a751a 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2826,7 +2826,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, ggml_cuda_topk_moe_ops(/*with_norm=*/false, /*delayed_softmax=*/true); if (ops.size() == topk_moe_ops_with_norm.size() && - ggml_can_fuse_subgraph(cgraph, node_idx, topk_moe_ops_with_norm, { node_idx + 3, node_idx + 8 })) { + ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 8 })) { ggml_tensor * softmax = cgraph->nodes[node_idx]; ggml_tensor * weights = cgraph->nodes[node_idx+8]; @@ -2836,7 +2836,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, } if (ops.size() == topk_moe_ops.size() 
&& - ggml_can_fuse_subgraph(cgraph, node_idx, topk_moe_ops, { node_idx + 3, node_idx + 4 })) { + ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 4 })) { ggml_tensor * softmax = cgraph->nodes[node_idx]; ggml_tensor * weights = cgraph->nodes[node_idx+4]; if (ggml_cuda_should_use_topk_moe(softmax, weights)) { @@ -2845,7 +2845,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, } if (ops.size() == topk_moe_ops_delayed_softmax.size() && - ggml_can_fuse_subgraph(cgraph, node_idx, topk_moe_ops_delayed_softmax, { node_idx + 2, node_idx + 5 })) { + ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 2, node_idx + 5 })) { ggml_tensor * softmax = cgraph->nodes[node_idx + 4]; ggml_tensor * weights = cgraph->nodes[node_idx + 5]; From fe6a9882acf5c02f96624ed8f80144100d7006cb Mon Sep 17 00:00:00 2001 From: Prajwal B Mehendarkar Date: Thu, 23 Oct 2025 17:07:31 +0530 Subject: [PATCH 3/7] Manually link -lbsd to resolve flock symbol on AIX (#16610) --- tools/imatrix/CMakeLists.txt | 5 +++++ tools/run/CMakeLists.txt | 6 ++++++ 2 files changed, 11 insertions(+) diff --git a/tools/imatrix/CMakeLists.txt b/tools/imatrix/CMakeLists.txt index 22f2fe5fd..5af6263f9 100644 --- a/tools/imatrix/CMakeLists.txt +++ b/tools/imatrix/CMakeLists.txt @@ -6,3 +6,8 @@ target_compile_features(${TARGET} PRIVATE cxx_std_17) if(LLAMA_TOOLS_INSTALL) install(TARGETS ${TARGET} RUNTIME) endif() + +if (CMAKE_SYSTEM_NAME MATCHES "AIX") + # AIX's flock() function comes from libbsd.a + target_link_libraries(${TARGET} PRIVATE -lbsd) +endif() diff --git a/tools/run/CMakeLists.txt b/tools/run/CMakeLists.txt index e52294ccc..6ad7534e2 100644 --- a/tools/run/CMakeLists.txt +++ b/tools/run/CMakeLists.txt @@ -13,5 +13,11 @@ endif () if(LLAMA_TOOLS_INSTALL) install(TARGETS ${TARGET} RUNTIME) endif() + +if (CMAKE_SYSTEM_NAME MATCHES "AIX") + # AIX's flock() function comes from libbsd.a + target_link_libraries(${TARGET} PRIVATE -lbsd) +endif() + 
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT} ${LLAMA_RUN_EXTRA_LIBS}) target_compile_features(${TARGET} PRIVATE cxx_std_17) From d0660f237a5c31771a3d6d1030ebe3e0c409ba92 Mon Sep 17 00:00:00 2001 From: Xuan-Son Nguyen Date: Thu, 23 Oct 2025 15:00:49 +0200 Subject: [PATCH 4/7] mtmd-cli : allow using --jinja (#16718) * mtmd-cli : allow using --jinja * support -sys * implement chat_history * fix clear memory * rm -sys support, added TODO --- common/arg.cpp | 2 +- tools/mtmd/mtmd-cli.cpp | 47 ++++++++++++++++++++++++++++------------- 2 files changed, 33 insertions(+), 16 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index 33ed7ae85..a25743c89 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -3435,7 +3435,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params) { params.use_jinja = true; } - ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN}).set_env("LLAMA_ARG_JINJA")); + ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_MTMD}).set_env("LLAMA_ARG_JINJA")); add_opt(common_arg( {"--reasoning-format"}, "FORMAT", "controls whether thought tags are allowed and/or extracted from the response, and in which format they're returned; one of:\n" diff --git a/tools/mtmd/mtmd-cli.cpp b/tools/mtmd/mtmd-cli.cpp index 5fde6ca0c..fd1fb6581 100644 --- a/tools/mtmd/mtmd-cli.cpp +++ b/tools/mtmd/mtmd-cli.cpp @@ -76,9 +76,11 @@ struct mtmd_cli_context { mtmd::bitmaps bitmaps; - // note: we know that gemma3 template is "linear", meaning each turn is completely separated to another - // so here we don't need to keep track of chat history + // chat template common_chat_templates_ptr tmpls; + std::vector chat_history; + bool use_jinja = false; + // TODO: support for --system-prompt with /clear command // support for legacy templates (models not having EOT token) llama_tokens antiprompt_tokens; @@ -108,6 +110,8 @@ struct mtmd_cli_context { } tmpls = 
common_chat_templates_init(model, params.chat_template); + use_jinja = params.use_jinja; + chat_history.clear(); LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(tmpls.get(), params.use_jinja, params.default_template_kwargs).c_str()); init_vision_context(params); @@ -193,19 +197,33 @@ static int generate_response(mtmd_cli_context & ctx, int n_predict) { return 1; } } + + std::string generated_text = common_detokenize(ctx.lctx, generated_tokens); + common_chat_msg msg; + msg.role = "assistant"; + msg.content = generated_text; + ctx.chat_history.push_back(std::move(msg)); + return 0; } -static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg, bool add_bos = false) { - common_chat_templates_inputs tmpl_inputs; - tmpl_inputs.messages = {msg}; - tmpl_inputs.add_generation_prompt = true; - tmpl_inputs.use_jinja = false; // jinja is buggy here - auto formatted_chat = common_chat_templates_apply(ctx.tmpls.get(), tmpl_inputs); - LOG_DBG("formatted_chat.prompt: %s\n", formatted_chat.prompt.c_str()); +static std::string chat_add_and_format(mtmd_cli_context & ctx, common_chat_msg & new_msg) { + LOG_DBG("chat_add_and_format: new_msg.role='%s', new_msg.content='%s'\n", + new_msg.role.c_str(), new_msg.content.c_str()); + auto formatted = common_chat_format_single(ctx.tmpls.get(), ctx.chat_history, + new_msg, new_msg.role == "user", + ctx.use_jinja); + ctx.chat_history.push_back(new_msg); + return formatted; +} + +static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg) { + bool add_bos = ctx.chat_history.empty(); + auto formatted_chat = chat_add_and_format(ctx, msg); + LOG_DBG("formatted_chat.prompt: %s\n", formatted_chat.c_str()); mtmd_input_text text; - text.text = formatted_chat.prompt.c_str(); + text.text = formatted_chat.c_str(); text.add_special = add_bos; text.parse_special = true; @@ -303,7 +321,7 @@ int main(int argc, char ** argv) { return 1; // error is already printed by libmtmd } } - if (eval_message(ctx, 
msg, true)) { + if (eval_message(ctx, msg)) { return 1; } if (!g_is_interrupted && generate_response(ctx, n_predict)) { @@ -322,7 +340,6 @@ int main(int argc, char ** argv) { LOG("\n /quit or /exit exit the program"); LOG("\n"); - bool is_first_msg = true; std::string content; while (!g_is_interrupted) { @@ -342,7 +359,8 @@ int main(int argc, char ** argv) { } if (line == "/clear") { ctx.n_past = 0; - llama_memory_seq_rm(llama_get_memory(ctx.lctx), 0, 1, -1); // keep BOS + ctx.chat_history.clear(); + llama_memory_clear(llama_get_memory(ctx.lctx), true); LOG("Chat history cleared\n\n"); continue; } @@ -367,7 +385,7 @@ int main(int argc, char ** argv) { common_chat_msg msg; msg.role = "user"; msg.content = content; - int ret = eval_message(ctx, msg, is_first_msg); + int ret = eval_message(ctx, msg); if (ret) { return 1; } @@ -376,7 +394,6 @@ int main(int argc, char ** argv) { return 1; } content.clear(); - is_first_msg = false; } } if (g_is_interrupted) LOG("\nInterrupted by user\n"); From dd62dcfab97e420949519fd0eac9fca7bf97e635 Mon Sep 17 00:00:00 2001 From: Julien Denize <40604584+juliendenize@users.noreply.github.com> Date: Thu, 23 Oct 2025 15:54:46 +0200 Subject: [PATCH 5/7] convert : Make mistral-common dependency optional (#16738) * Make mistral-common dependency optional * Fix typing --- convert_hf_to_gguf.py | 43 +++++++++++++++---- gguf-py/gguf/vocab.py | 8 ++-- .../requirements-convert_hf_to_gguf.txt | 2 - 3 files changed, 38 insertions(+), 15 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index ed99dc847..7b49969c0 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -29,12 +29,29 @@ if 'NO_LOCAL_GGUF' not in os.environ: sys.path.insert(1, str(Path(__file__).parent / 'gguf-py')) import gguf from gguf.vocab import MistralTokenizerType, MistralVocab -from mistral_common.tokens.tokenizers.base import TokenizerVersion -from mistral_common.tokens.tokenizers.multimodal import DATASET_MEAN, DATASET_STD -from 
mistral_common.tokens.tokenizers.tekken import Tekkenizer -from mistral_common.tokens.tokenizers.sentencepiece import ( - SentencePieceTokenizer, -) + +try: + from mistral_common.tokens.tokenizers.base import TokenizerVersion # pyright: ignore[reportMissingImports] + from mistral_common.tokens.tokenizers.multimodal import DATASET_MEAN as _MISTRAL_COMMON_DATASET_MEAN, DATASET_STD as _MISTRAL_COMMON_DATASET_STD # pyright: ignore[reportMissingImports] + from mistral_common.tokens.tokenizers.tekken import Tekkenizer # pyright: ignore[reportMissingImports] + from mistral_common.tokens.tokenizers.sentencepiece import ( # pyright: ignore[reportMissingImports] + SentencePieceTokenizer, + ) + + _mistral_common_installed = True + _mistral_import_error_msg = "" +except ImportError: + _MISTRAL_COMMON_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073) + _MISTRAL_COMMON_DATASET_STD = (0.26862954, 0.26130258, 0.27577711) + + _mistral_common_installed = False + TokenizerVersion = None + Tekkenizer = None + SentencePieceTokenizer = None + _mistral_import_error_msg = ( + "Mistral format requires `mistral-common` to be installed. Please run " + "`pip install mistral-common[image,audio]` to install it." 
+ ) logger = logging.getLogger("hf-to-gguf") @@ -107,6 +124,9 @@ class ModelBase: type(self) is MmprojModel: raise TypeError(f"{type(self).__name__!r} should not be directly instantiated") + if self.is_mistral_format and not _mistral_common_installed: + raise ImportError(_mistral_import_error_msg) + self.dir_model = dir_model self.ftype = ftype self.fname_out = fname_out @@ -1363,8 +1383,8 @@ class MmprojModel(ModelBase): self.gguf_writer.add_vision_head_count(self.find_vparam(["num_attention_heads"])) # preprocessor config - image_mean = DATASET_MEAN if self.is_mistral_format else self.preprocessor_config["image_mean"] - image_std = DATASET_STD if self.is_mistral_format else self.preprocessor_config["image_std"] + image_mean = _MISTRAL_COMMON_DATASET_MEAN if self.is_mistral_format else self.preprocessor_config["image_mean"] + image_std = _MISTRAL_COMMON_DATASET_STD if self.is_mistral_format else self.preprocessor_config["image_std"] self.gguf_writer.add_vision_image_mean(image_mean) self.gguf_writer.add_vision_image_std(image_std) @@ -2033,6 +2053,9 @@ class LlamaModel(TextModel): self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 32) def _set_vocab_mistral(self): + if not _mistral_common_installed: + raise ImportError(_mistral_import_error_msg) + vocab = MistralVocab(self.dir_model) logger.info( f"Converting tokenizer {vocab.tokenizer_type} of size {vocab.vocab_size}." 
@@ -9212,7 +9235,7 @@ class MistralModel(LlamaModel): @staticmethod def get_community_chat_template(vocab: MistralVocab, templates_dir: Path, is_mistral_format: bool): - assert TokenizerVersion is not None, "mistral_common is not installed" + assert TokenizerVersion is not None and Tekkenizer is not None and SentencePieceTokenizer is not None, _mistral_import_error_msg assert isinstance(vocab.tokenizer, (Tekkenizer, SentencePieceTokenizer)), ( f"Expected Tekkenizer or SentencePieceTokenizer, got {type(vocab.tokenizer)}" ) @@ -9594,6 +9617,8 @@ def main() -> None: fname_out = ModelBase.add_prefix_to_filename(fname_out, "mmproj-") is_mistral_format = args.mistral_format + if is_mistral_format and not _mistral_common_installed: + raise ImportError(_mistral_import_error_msg) disable_mistral_community_chat_template = args.disable_mistral_community_chat_template with torch.inference_mode(): diff --git a/gguf-py/gguf/vocab.py b/gguf-py/gguf/vocab.py index 7111557bf..5c6817109 100644 --- a/gguf-py/gguf/vocab.py +++ b/gguf-py/gguf/vocab.py @@ -14,12 +14,12 @@ except ImportError: SentencePieceProcessor = None try: - from mistral_common.tokens.tokenizers.mistral import MistralTokenizer - from mistral_common.tokens.tokenizers.tekken import Tekkenizer - from mistral_common.tokens.tokenizers.utils import ( + from mistral_common.tokens.tokenizers.mistral import MistralTokenizer # pyright: ignore[reportMissingImports] + from mistral_common.tokens.tokenizers.tekken import Tekkenizer # pyright: ignore[reportMissingImports] + from mistral_common.tokens.tokenizers.utils import ( # pyright: ignore[reportMissingImports] _filter_valid_tokenizer_files, ) - from mistral_common.tokens.tokenizers.sentencepiece import ( + from mistral_common.tokens.tokenizers.sentencepiece import ( # pyright: ignore[reportMissingImports] SentencePieceTokenizer, ) except ImportError: diff --git a/requirements/requirements-convert_hf_to_gguf.txt b/requirements/requirements-convert_hf_to_gguf.txt index 
90c98c3ff..122b4788d 100644 --- a/requirements/requirements-convert_hf_to_gguf.txt +++ b/requirements/requirements-convert_hf_to_gguf.txt @@ -1,5 +1,3 @@ -mistral-common>=1.8.3 - -r ./requirements-convert_legacy_llama.txt --extra-index-url https://download.pytorch.org/whl/cpu From 0bf47a1dbba4d36f2aff4e8c34b06210ba34e688 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Thu, 23 Oct 2025 21:30:17 +0200 Subject: [PATCH 6/7] server: add memory breakdown print (#16740) --- tools/server/server.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 85849e160..4124bffa4 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -5714,6 +5714,7 @@ int main(int argc, char ** argv) { clean_up(); t.join(); + llama_memory_breakdown_print(ctx_server.ctx); return 0; } From f8f071faddf32ea09f4234edb6e809b380a9ee26 Mon Sep 17 00:00:00 2001 From: compilade Date: Thu, 23 Oct 2025 16:31:41 -0400 Subject: [PATCH 7/7] convert : handle pre-quantized models (#14810) * convert : begin handling pre-quantized models * convert : fix conversion from FP8 for Deepseek-V3.1-Base --- convert_hf_to_gguf.py | 244 ++++++++++++++++++++++++++++++------------ 1 file changed, 177 insertions(+), 67 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 7b49969c0..3e3db999c 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -90,10 +90,8 @@ class ModelBase: use_temp_file: bool lazy: bool dry_run: bool - part_names: list[str] - is_safetensors: bool hparams: dict[str, Any] - tensor_names: set[str] | None + model_tensors: dict[str, Callable[[], Tensor]] gguf_writer: gguf.GGUFWriter model_name: str | None metadata_override: Path | None @@ -137,25 +135,8 @@ class ModelBase: self.dry_run = dry_run self.remote_hf_model_id = remote_hf_model_id self.sentence_transformers_dense_modules = sentence_transformers_dense_modules - if remote_hf_model_id is not None: - self.is_safetensors = True 
- - def get_remote_tensors() -> Iterator[tuple[str, Tensor]]: - logger.info(f"Using remote model with HuggingFace id: {remote_hf_model_id}") - remote_tensors = gguf.utility.SafetensorRemote.get_list_tensors_hf_model(remote_hf_model_id) - self.tensor_names = set(name for name in remote_tensors.keys()) - for name, remote_tensor in remote_tensors.items(): - yield (name, LazyTorchTensor.from_remote_tensor(remote_tensor)) - - self.get_tensors = get_remote_tensors - else: - prefix = "model" if not self.is_mistral_format else "consolidated" - self.part_names = ModelBase.get_model_part_names(self.dir_model, prefix, ".safetensors") - self.is_safetensors = len(self.part_names) > 0 - if not self.is_safetensors: - self.part_names = ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin") self.hparams = ModelBase.load_hparams(self.dir_model, self.is_mistral_format) if hparams is None else hparams - self.tensor_names = None + self.model_tensors = self.index_tensors(remote_hf_model_id=remote_hf_model_id) self.metadata_override = metadata_override self.model_name = model_name self.dir_model_card = dir_model # overridden in convert_lora_to_gguf.py @@ -171,6 +152,8 @@ class ModelBase: logger.info(f"choosing --outtype bf16 from first tensor type ({first_tensor.dtype})") self.ftype = gguf.LlamaFileType.MOSTLY_BF16 + self.dequant_model() + # Configure GGUF Writer self.gguf_writer = gguf.GGUFWriter(path=None, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file, split_max_tensors=split_max_tensors, split_max_size=split_max_size, dry_run=dry_run, small_first_shard=small_first_shard) @@ -192,67 +175,215 @@ class ModelBase: return None raise KeyError(f"could not find any of: {keys}") - def get_tensors(self) -> Iterator[tuple[str, Tensor]]: - tensor_names_from_parts: set[str] = set() + def index_tensors(self, remote_hf_model_id: str | None = None) -> dict[str, Callable[[], Tensor]]: + tensors: dict[str, Callable[[], Tensor]] 
= {} + + if remote_hf_model_id is not None: + is_safetensors = True + + logger.info(f"Using remote model with HuggingFace id: {remote_hf_model_id}") + remote_tensors = gguf.utility.SafetensorRemote.get_list_tensors_hf_model(remote_hf_model_id) + for name, remote_tensor in remote_tensors.items(): + tensors[name] = lambda r=remote_tensor: LazyTorchTensor.from_remote_tensor(r) + + return tensors + + prefix = "model" if not self.is_mistral_format else "consolidated" + part_names: list[str] = ModelBase.get_model_part_names(self.dir_model, prefix, ".safetensors") + is_safetensors: bool = len(part_names) > 0 + if not is_safetensors: + part_names = ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin") + + tensor_names_from_index: set[str] = set() if not self.is_mistral_format: - index_name = "model.safetensors" if self.is_safetensors else "pytorch_model.bin" + index_name = "model.safetensors" if is_safetensors else "pytorch_model.bin" index_name += ".index.json" index_file = self.dir_model / index_name if index_file.is_file(): - self.tensor_names = set() logger.info(f"gguf: loading model weight map from '{index_name}'") with open(index_file, "r", encoding="utf-8") as f: index: dict[str, Any] = json.load(f) weight_map = index.get("weight_map") if weight_map is None or not isinstance(weight_map, dict): raise ValueError(f"Can't load 'weight_map' from {index_name!r}") - self.tensor_names.update(weight_map.keys()) + tensor_names_from_index.update(weight_map.keys()) else: - self.tensor_names = tensor_names_from_parts weight_map = {} else: - self.tensor_names = tensor_names_from_parts weight_map = {} - for part_name in self.part_names: - logger.info(f"gguf: loading model part '{part_name}'") + for part_name in part_names: + logger.info(f"gguf: indexing model part '{part_name}'") ctx: ContextManager[Any] - if self.is_safetensors: + if is_safetensors: from safetensors import safe_open ctx = cast(ContextManager[Any], safe_open(self.dir_model / part_name, 
framework="pt", device="cpu")) else: ctx = contextlib.nullcontext(torch.load(str(self.dir_model / part_name), map_location="cpu", mmap=True, weights_only=True)) with ctx as model_part: - tensor_names_from_parts.update(model_part.keys()) + assert model_part is not None for name in model_part.keys(): - if self.is_safetensors: + if is_safetensors: if self.lazy: data = model_part.get_slice(name) - data = LazyTorchTensor.from_safetensors_slice(data) + data_gen = lambda data=data: LazyTorchTensor.from_safetensors_slice(data) # noqa: E731 else: data = model_part.get_tensor(name) + data_gen = lambda data=data: data # noqa: E731 else: data = model_part[name] if self.lazy: - data = LazyTorchTensor.from_eager(data) - yield name, data + data_gen = lambda data=data: LazyTorchTensor.from_eager(data) # noqa: E731 + else: + data_gen = lambda data=data: data # noqa: E731 + tensors[name] = data_gen # verify tensor name presence and identify potentially missing files - if len(tensor_names_from_parts.symmetric_difference(self.tensor_names)) > 0: - missing = sorted(self.tensor_names.difference(tensor_names_from_parts)) - extra = sorted(tensor_names_from_parts.difference(self.tensor_names)) - missing_files = sorted(set(weight_map[n] for n in missing if n in weight_map)) - if len(extra) == 0 and len(missing_files) > 0: - raise ValueError(f"Missing or incomplete model files: {missing_files}\n" - f"Missing tensors: {missing}") + if len(tensor_names_from_index) > 0: + tensor_names_from_parts = set(tensors.keys()) + if len(tensor_names_from_parts.symmetric_difference(tensor_names_from_index)) > 0: + missing = sorted(tensor_names_from_index.difference(tensor_names_from_parts)) + extra = sorted(tensor_names_from_parts.difference(tensor_names_from_index)) + missing_files = sorted(set(weight_map[n] for n in missing if n in weight_map)) + if len(extra) == 0 and len(missing_files) > 0: + raise ValueError(f"Missing or incomplete model files: {missing_files}\n" + f"Missing tensors: {missing}") + 
else: + raise ValueError("Mismatch between weight map and model parts for tensor names:\n" + f"Missing tensors: {missing}\n" + f"Extra tensors: {extra}") + + return tensors + + def dequant_model(self): + tensors_to_remove: list[str] = [] + new_tensors: dict[str, Callable[[], Tensor]] = {} + + if (quant_config := self.hparams.get("quantization_config")) and isinstance(quant_config, dict): + quant_method = quant_config.get("quant_method") + + def dequant_bitnet(weight: Tensor, scale: Tensor) -> Tensor: + weight = weight.view(torch.uint8) + orig_shape = weight.shape + + shift = torch.tensor([0, 2, 4, 6], dtype=torch.uint8).reshape((4, *(1 for _ in range(len(orig_shape))))) + data = weight.unsqueeze(0).expand((4, *orig_shape)) >> shift + data = data & 3 + data = (data.float() - 1).reshape((orig_shape[0] * 4, *orig_shape[1:])) + + # The scale is inverted + return data / scale.float() + + def dequant_simple(weight: Tensor, scale: Tensor) -> Tensor: + scale = scale.float() + + if (weight_block_size := quant_config.get("weight_block_size")): + # TODO: make sure it's a list of integers + for i, size in enumerate(weight_block_size): + scale = scale.repeat_interleave(size, i) + # unpad the scale (e.g. 
when the tensor size isn't a multiple of the block size) + scale = scale[tuple(slice(0, size) for size in weight.shape)] + + return weight.float() * scale + + # ref: https://github.com/ModelCloud/GPTQModel/blob/037c5c0f6c9e33c500d975b038d02e7ca437546d/gptqmodel/nn_modules/qlinear/__init__.py#L437-L476 + def dequant_gptq(g_idx: Tensor, qweight: Tensor, qzeros: Tensor, scales: Tensor) -> Tensor: + bits = quant_config["bits"] + assert bits in (2, 3, 4, 8) + assert qweight.dtype == qzeros.dtype + maxq = (2 ** bits) - 1 + weight = None + zeros = None + pack_dtype_bits = qweight.dtype.itemsize * 8 + + if bits in [2, 4, 8]: + pack_factor = pack_dtype_bits // bits + wf = torch.tensor(list(range(0, pack_dtype_bits, bits)), dtype=torch.int32).unsqueeze(0) + if self.lazy: + wf = LazyTorchTensor.from_eager(wf) + + zeros = torch.bitwise_right_shift( + qzeros.unsqueeze(2).expand(-1, -1, pack_factor), + wf.unsqueeze(0) + ).to(torch.int16 if bits == 8 else torch.int8) + zeros = torch.bitwise_and(zeros, maxq).reshape(scales.shape) + + weight = torch.bitwise_and( + torch.bitwise_right_shift( + qweight.unsqueeze(1).expand(-1, pack_factor, -1), + wf.unsqueeze(-1) + ).to(torch.int16 if bits == 8 else torch.int8), + maxq + ) + elif bits == 3: + raise NotImplementedError("3-bit gptq dequantization is not yet implemented") + + assert weight is not None + assert zeros is not None + + weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2]) + + # gptq_v2 doesn't need to offset zeros + if quant_config.get("checkpoint_format", "gptq") == "gptq": + zeros += 1 + + return (scales[g_idx].float() * (weight - zeros[g_idx]).float()).T + + if quant_method == "bitnet": + for name in self.model_tensors.keys(): + if name.endswith(".weight_scale"): + weight_name = name.removesuffix("_scale") + w = self.model_tensors[weight_name] + s = self.model_tensors[name] + self.model_tensors[weight_name] = lambda w=w, s=s: dequant_bitnet(w(), s()) + tensors_to_remove.append(name) + elif 
quant_method == "fp8": + for name in self.model_tensors.keys(): + if name.endswith(".weight_scale_inv"): + weight_name = name.removesuffix("_scale_inv") + w = self.model_tensors[weight_name] + s = self.model_tensors[name] + self.model_tensors[weight_name] = lambda w=w, s=s: dequant_simple(w(), s()) + tensors_to_remove.append(name) + elif quant_method == "gptq": + for name in self.model_tensors.keys(): + if name.endswith(".qweight"): + base_name = name.removesuffix(".qweight") + g_idx = self.model_tensors[base_name + ".g_idx"] + qweight = self.model_tensors[base_name + ".qweight"] + qzeros = self.model_tensors[base_name + ".qzeros"] + scales = self.model_tensors[base_name + ".scales"] + new_tensors[base_name + ".weight"] = ( + lambda g=g_idx, z=qzeros, w=qweight, s=scales: dequant_gptq( + g(), w(), z(), s() + ) + ) + tensors_to_remove += [ + base_name + n + for n in ( + ".g_idx", + ".qzeros", + ".qweight", + ".scales", + ) + ] else: - raise ValueError("Mismatch between weight map and model parts for tensor names:\n" - f"Missing tensors: {missing}\n" - f"Extra tensors: {extra}") + raise NotImplementedError(f"Quant method is not yet supported: {quant_method!r}") + + for name in tensors_to_remove: + if name in self.model_tensors: + del self.model_tensors[name] + + for name, value in new_tensors.items(): + self.model_tensors[name] = value + + def get_tensors(self) -> Iterator[tuple[str, Tensor]]: + for name, gen in self.model_tensors.items(): + yield name, gen() def format_tensor_name(self, key: gguf.MODEL_TENSOR, bid: int | None = None, suffix: str = ".weight") -> str: if key not in gguf.MODEL_TENSORS[self.model_arch]: @@ -4381,27 +4512,6 @@ class CodeShellModel(TextModel): self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) self.gguf_writer.add_rope_scaling_factor(1.0) - _has_tok_embd = False - - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: - del bid # unused - - output_name = 
self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT) - tok_embd_name = self.format_tensor_name(gguf.MODEL_TENSOR.TOKEN_EMBD) - - new_name = self.map_tensor_name(name) - - # assuming token_embd.weight is seen before output.weight - if not self._has_tok_embd and new_name == self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT): - # even though the tensor file(s) does not contain the word embeddings they are still in the weight map - if self.tensor_names and "transformer.wte.weight" in self.tensor_names: - logger.debug(f"{tok_embd_name} not found before {output_name}, assuming they are tied") - self.tensor_names.remove("transformer.wte.weight") - elif new_name == tok_embd_name: - self._has_tok_embd = True - - return [(new_name, data_torch)] - @ModelBase.register("InternLM2ForCausalLM") class InternLM2Model(TextModel):