mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2026-05-17 04:09:19 +00:00
* move conversion code to a dedicated conversion directory and split the files akin to the src/models architecture --------- Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
333 lines
12 KiB
Python
333 lines
12 KiB
Python
from __future__ import annotations
|
|
|
|
from .base import (
|
|
ModelBase, TextModel, MmprojModel, ModelType, SentencePieceTokenTypes,
|
|
logger, _mistral_common_installed, _mistral_import_error_msg,
|
|
get_model_architecture, LazyTorchTensor,
|
|
)
|
|
from typing import Type
|
|
|
|
|
|
# Public API: re-exports from .base plus the loader helpers defined below
# (get_model_class / print_registered_models / load_all_models).
__all__ = [
    "ModelBase", "TextModel", "MmprojModel", "ModelType", "SentencePieceTokenTypes",
    "get_model_architecture", "LazyTorchTensor", "logger",
    "_mistral_common_installed", "_mistral_import_error_msg",
    "get_model_class", "print_registered_models", "load_all_models",
]
|
|
|
|
|
|
# Maps a HuggingFace `architectures` entry to the conversion submodule
# (imported as f"conversion.{value}") that registers its TEXT model class
# via @ModelBase.register(). Keys are sorted; several architectures share
# one submodule (e.g. many *Bert* variants map to "bert").
TEXT_MODEL_MAP: dict[str, str] = {
    "AfmoeForCausalLM": "afmoe",
    "ApertusForCausalLM": "llama",
    "ArceeForCausalLM": "llama",
    "ArcticForCausalLM": "arctic",
    "AudioFlamingo3ForConditionalGeneration": "qwen",
    "BaiChuanForCausalLM": "baichuan",
    "BaichuanForCausalLM": "baichuan",
    "BailingMoeForCausalLM": "bailingmoe",
    "BailingMoeV2ForCausalLM": "bailingmoe",
    "BambaForCausalLM": "granite",
    "BertForMaskedLM": "bert",
    "BertForSequenceClassification": "bert",
    "BertModel": "bert",
    "BitnetForCausalLM": "bitnet",
    "BloomForCausalLM": "bloom",
    "BloomModel": "bloom",
    "CamembertModel": "bert",
    "ChameleonForCausalLM": "chameleon",
    "ChameleonForConditionalGeneration": "chameleon",
    "ChatGLMForConditionalGeneration": "chatglm",
    "ChatGLMModel": "chatglm",
    "CodeShellForCausalLM": "codeshell",
    "CogVLMForCausalLM": "cogvlm",
    "Cohere2ForCausalLM": "command_r",
    "CohereForCausalLM": "command_r",
    "DbrxForCausalLM": "dbrx",
    "DeciLMForCausalLM": "deci",
    "DeepseekForCausalLM": "deepseek",
    "DeepseekV2ForCausalLM": "deepseek",
    "DeepseekV3ForCausalLM": "deepseek",
    "DistilBertForMaskedLM": "bert",
    "DistilBertForSequenceClassification": "bert",
    "DistilBertModel": "bert",
    "Dots1ForCausalLM": "dots1",
    "DotsOCRForCausalLM": "qwen",
    "DreamModel": "dream",
    "Ernie4_5ForCausalLM": "ernie",
    "Ernie4_5_ForCausalLM": "ernie",
    "Ernie4_5_MoeForCausalLM": "ernie",
    "EuroBertModel": "bert",
    "Exaone4ForCausalLM": "exaone",
    "ExaoneForCausalLM": "exaone",
    "ExaoneMoEForCausalLM": "exaone",
    "FalconForCausalLM": "falcon",
    "FalconH1ForCausalLM": "falcon_h1",
    "FalconMambaForCausalLM": "mamba",
    "GPT2LMHeadModel": "gpt2",
    "GPTBigCodeForCausalLM": "starcoder",
    "GPTNeoXForCausalLM": "gptneox",
    "GPTRefactForCausalLM": "refact",
    "Gemma2ForCausalLM": "gemma",
    "Gemma3ForCausalLM": "gemma",
    "Gemma3ForConditionalGeneration": "gemma",
    "Gemma3TextModel": "gemma",
    "Gemma3nForCausalLM": "gemma",
    "Gemma3nForConditionalGeneration": "gemma",
    "Gemma4ForConditionalGeneration": "gemma",
    "GemmaForCausalLM": "gemma",
    "Glm4ForCausalLM": "glm",
    "Glm4MoeForCausalLM": "glm",
    "Glm4MoeLiteForCausalLM": "glm",
    "Glm4vForConditionalGeneration": "glm",
    "Glm4vMoeForConditionalGeneration": "glm",
    "GlmForCausalLM": "chatglm",
    "GlmMoeDsaForCausalLM": "glm",
    "GlmOcrForConditionalGeneration": "glm",
    "GptOssForCausalLM": "gpt_oss",
    "GraniteForCausalLM": "granite",
    "GraniteMoeForCausalLM": "granite",
    "GraniteMoeHybridForCausalLM": "granite",
    "GraniteMoeSharedForCausalLM": "granite",
    "GraniteSpeechForConditionalGeneration": "granite",
    "Grok1ForCausalLM": "grok",
    "GrokForCausalLM": "grok",
    "GroveMoeForCausalLM": "grovemoe",
    "HunYuanDenseV1ForCausalLM": "hunyuan",
    "HunYuanMoEV1ForCausalLM": "hunyuan",
    "HunYuanVLForConditionalGeneration": "hunyuan",
    "IQuestCoderForCausalLM": "llama",
    "InternLM2ForCausalLM": "internlm",
    "InternLM3ForCausalLM": "internlm",
    "JAISLMHeadModel": "jais",
    "Jais2ForCausalLM": "jais",
    "JambaForCausalLM": "jamba",
    "JanusForConditionalGeneration": "januspro",
    "JinaBertForMaskedLM": "bert",
    "JinaBertModel": "bert",
    "JinaEmbeddingsV5Model": "bert",
    "KORMoForCausalLM": "qwen",
    "KimiK25ForConditionalGeneration": "deepseek",
    "KimiLinearForCausalLM": "kimi_linear",
    "KimiLinearModel": "kimi_linear",
    "KimiVLForConditionalGeneration": "deepseek",
    "LFM2ForCausalLM": "lfm2",
    "LLaDAMoEModel": "llada",
    "LLaDAMoEModelLM": "llada",
    "LLaDAModelLM": "llada",
    "LLaMAForCausalLM": "llama",
    "Lfm25AudioTokenizer": "lfm2",
    "Lfm2ForCausalLM": "lfm2",
    "Lfm2Model": "lfm2",
    "Lfm2MoeForCausalLM": "lfm2",
    "Llama4ForCausalLM": "llama",
    "Llama4ForConditionalGeneration": "llama",
    "LlamaBidirectionalModel": "llama",
    "LlamaForCausalLM": "llama",
    "LlamaModel": "llama",
    "LlavaForConditionalGeneration": "llama",
    "LlavaStableLMEpochForCausalLM": "stablelm",
    "MPTForCausalLM": "mpt",
    "MT5ForConditionalGeneration": "t5",
    "MaincoderForCausalLM": "maincoder",
    "Mamba2ForCausalLM": "mamba",
    "MambaForCausalLM": "mamba",
    "MambaLMHeadModel": "mamba",
    "MiMoV2FlashForCausalLM": "mimo",
    "MiMoV2ForCausalLM": "mimo",
    "MiniCPM3ForCausalLM": "minicpm",
    "MiniCPMForCausalLM": "minicpm",
    "MiniCPMV4_6ForConditionalGeneration": "minicpm",
    "MiniMaxM2ForCausalLM": "minimax",
    "Ministral3ForCausalLM": "mistral3",
    "Mistral3ForConditionalGeneration": "mistral3",
    "MistralForCausalLM": "llama",
    "MixtralForCausalLM": "llama",
    "ModernBertForMaskedLM": "bert",
    "ModernBertForSequenceClassification": "bert",
    "ModernBertModel": "bert",
    "NemotronForCausalLM": "nemotron",
    "NemotronHForCausalLM": "nemotron",
    "NeoBERT": "bert",
    "NeoBERTForSequenceClassification": "bert",
    "NeoBERTLMHead": "bert",
    "NomicBertModel": "bert",
    "OLMoForCausalLM": "olmo",
    "Olmo2ForCausalLM": "olmo",
    "Olmo3ForCausalLM": "olmo",
    "OlmoForCausalLM": "olmo",
    "OlmoeForCausalLM": "olmo",
    "OpenELMForCausalLM": "openelm",
    "OrionForCausalLM": "orion",
    "PLMForCausalLM": "plm",
    "PLaMo2ForCausalLM": "plamo",
    "PLaMo3ForCausalLM": "plamo",
    "PaddleOCRVLForConditionalGeneration": "ernie",
    "PanguEmbeddedForCausalLM": "pangu",
    "Phi3ForCausalLM": "phi",
    "Phi4ForCausalLMV": "phi",
    "PhiForCausalLM": "phi",
    "PhiMoEForCausalLM": "phi",
    "Plamo2ForCausalLM": "plamo",
    "Plamo3ForCausalLM": "plamo",
    "PlamoForCausalLM": "plamo",
    "QWenLMHeadModel": "qwen",
    "Qwen2AudioForConditionalGeneration": "qwen",
    "Qwen2ForCausalLM": "qwen",
    "Qwen2Model": "qwen",
    "Qwen2MoeForCausalLM": "qwen",
    "Qwen2VLForConditionalGeneration": "qwenvl",
    "Qwen2VLModel": "qwenvl",
    "Qwen2_5OmniModel": "qwenvl",
    "Qwen2_5_VLForConditionalGeneration": "qwenvl",
    "Qwen3ASRForConditionalGeneration": "qwen3vl",
    "Qwen3ForCausalLM": "qwen",
    "Qwen3Model": "qwen",
    "Qwen3MoeForCausalLM": "qwen",
    "Qwen3NextForCausalLM": "qwen",
    "Qwen3OmniMoeForConditionalGeneration": "qwen3vl",
    "Qwen3VLForConditionalGeneration": "qwen3vl",
    "Qwen3VLMoeForConditionalGeneration": "qwen3vl",
    "Qwen3_5ForCausalLM": "qwen",
    "Qwen3_5ForConditionalGeneration": "qwen",
    "Qwen3_5MoeForCausalLM": "qwen",
    "Qwen3_5MoeForConditionalGeneration": "qwen",
    "RND1": "qwen",
    "RWForCausalLM": "falcon",
    "RWKV6Qwen2ForCausalLM": "rwkv",
    "RWKV7ForCausalLM": "rwkv",
    "RobertaForSequenceClassification": "bert",
    "RobertaModel": "bert",
    "RuGPT3XLForCausalLM": "gpt2",
    "Rwkv6ForCausalLM": "rwkv",
    "Rwkv7ForCausalLM": "rwkv",
    "RwkvHybridForCausalLM": "rwkv",
    "Sarashina2VisionForCausalLM": "sarashina2",
    "SarvamMoEForCausalLM": "bailingmoe",
    "SeedOssForCausalLM": "olmo",
    "SmallThinkerForCausalLM": "smallthinker",
    "SmolLM3ForCausalLM": "llama",
    "SolarOpenForCausalLM": "glm",
    "StableLMEpochForCausalLM": "stablelm",
    "StableLmForCausalLM": "stablelm",
    "Starcoder2ForCausalLM": "starcoder",
    "Step3p5ForCausalLM": "step3",
    "StepVLForConditionalGeneration": "step3",
    "T5EncoderModel": "t5",
    "T5ForConditionalGeneration": "t5",
    "T5WithLMHeadModel": "t5",
    "UMT5ForConditionalGeneration": "t5",
    "UMT5Model": "t5",
    "UltravoxModel": "ultravox",
    "VLlama3ForCausalLM": "llama",
    "VoxtralForConditionalGeneration": "llama",
    "WavTokenizerDec": "wavtokenizer",
    "XLMRobertaForSequenceClassification": "bert",
    "XLMRobertaModel": "bert",
    "XverseForCausalLM": "xverse",
    "YoutuForCausalLM": "deepseek",
    "YoutuVLForConditionalGeneration": "deepseek",
    # Fully-qualified names used by checkpoints that register custom modeling code.
    "modeling_grove_moe.GroveMoeForCausalLM": "grovemoe",
    "modeling_sarvam_moe.SarvamMoEForCausalLM": "bailingmoe",
}
|
|
|
|
|
|
# Maps a HuggingFace `architectures` entry to the conversion submodule that
# registers its MMPROJ (multimodal projector) class. An architecture may map
# to a different submodule here than in TEXT_MODEL_MAP (e.g. DotsOCR:
# text -> "qwen", mmproj -> "dotsocr").
MMPROJ_MODEL_MAP: dict[str, str] = {
    "AudioFlamingo3ForConditionalGeneration": "ultravox",
    "CogVLMForCausalLM": "cogvlm",
    "DeepseekOCRForCausalLM": "deepseek",
    "DotsOCRForCausalLM": "dotsocr",
    "Gemma3ForConditionalGeneration": "gemma",
    "Gemma3nForConditionalGeneration": "gemma",
    "Gemma4ForConditionalGeneration": "gemma",
    "Glm4vForConditionalGeneration": "qwen3vl",
    "Glm4vMoeForConditionalGeneration": "qwen3vl",
    "GlmOcrForConditionalGeneration": "qwen3vl",
    "GlmasrModel": "ultravox",
    "GraniteSpeechForConditionalGeneration": "granite",
    "HunYuanVLForConditionalGeneration": "hunyuan",
    "Idefics3ForConditionalGeneration": "smolvlm",
    "InternVisionModel": "internvl",
    "JanusForConditionalGeneration": "januspro",
    "KimiK25ForConditionalGeneration": "kimivl",
    "KimiVLForConditionalGeneration": "kimivl",
    "Lfm2AudioForConditionalGeneration": "lfm2",
    "Lfm2VlForConditionalGeneration": "lfm2",
    "LightOnOCRForConditionalGeneration": "lighton_ocr",
    "Llama4ForConditionalGeneration": "llama4",
    "LlavaForConditionalGeneration": "llava",
    "MERaLiON2ForConditionalGeneration": "ultravox",
    "MiMoV2ForCausalLM": "mimo",
    "MiniCPMV4_6ForConditionalGeneration": "minicpm",
    "Mistral3ForConditionalGeneration": "llava",
    "NemotronH_Nano_VL_V2": "nemotron",
    "PaddleOCRVisionModel": "ernie",
    "Phi4ForCausalLMV": "phi",
    "Qwen2AudioForConditionalGeneration": "ultravox",
    "Qwen2VLForConditionalGeneration": "qwenvl",
    "Qwen2VLModel": "qwenvl",
    "Qwen2_5OmniModel": "qwenvl",
    "Qwen2_5_VLForConditionalGeneration": "qwenvl",
    "Qwen3ASRForConditionalGeneration": "qwen3vl",
    "Qwen3OmniMoeForConditionalGeneration": "qwen3vl",
    "Qwen3VLForConditionalGeneration": "qwen3vl",
    "Qwen3VLMoeForConditionalGeneration": "qwen3vl",
    "Qwen3_5ForConditionalGeneration": "qwen3vl",
    "Qwen3_5MoeForConditionalGeneration": "qwen3vl",
    "RADIOModel": "nemotron",
    "Sarashina2VisionForCausalLM": "sarashina2",
    "SmolVLMForConditionalGeneration": "smolvlm",
    "StepVLForConditionalGeneration": "step3",
    "UltravoxModel": "ultravox",
    "VoxtralForConditionalGeneration": "ultravox",
    "YoutuVLForConditionalGeneration": "youtuvl",
}
|
|
|
|
|
|
# Deduplicated, deterministic list of submodules backing each map;
# these are the import targets for load_all_models().
_TEXT_MODEL_MODULES = sorted(set(TEXT_MODEL_MAP.values()))
_MMPROJ_MODEL_MODULES = sorted(set(MMPROJ_MODEL_MAP.values()))

# Submodules already imported successfully; lets load_all_models() be
# called repeatedly without re-importing (or re-warning about) modules.
_loaded_text_modules: set[str] = set()
_loaded_mmproj_modules: set[str] = set()
|
|
|
|
|
|
def load_all_models() -> None:
    """Import all model modules to trigger @ModelBase.register() decorators.

    Idempotent: modules that were already imported successfully are skipped,
    and a module that fails to import is only warned about (it will be
    retried on the next call).
    """
    # The text and mmproj paths were previously two copy-pasted loops;
    # both now share one helper. The old len(...) != len(...) fast-path
    # guards were redundant with the per-module membership check.
    _import_model_modules(_TEXT_MODEL_MODULES, _loaded_text_modules)
    _import_model_modules(_MMPROJ_MODEL_MODULES, _loaded_mmproj_modules)


def _import_model_modules(module_names: list[str], loaded: set[str]) -> None:
    """Import each conversion submodule in *module_names* not yet in *loaded*.

    Successful imports are recorded in *loaded* (mutated in place); failures
    are logged as warnings and left out so a later call can retry them.
    """
    for module_name in module_names:
        if module_name in loaded:
            continue
        try:
            __import__(f"conversion.{module_name}")
        except Exception as e:
            # Best-effort: one broken model module must not break the rest.
            logger.warning(f"Failed to load model module {module_name}: {e}")
        else:
            loaded.add(module_name)
|
|
|
|
|
|
def get_model_class(name: str, mmproj: bool = False) -> Type[ModelBase]:
    """Dynamically import and return a model class by its HuggingFace architecture name.

    Looks *name* up in the mmproj or text architecture map, imports the
    backing conversion submodule (which registers the class as a side
    effect), and returns the registered class.

    Raises:
        NotImplementedError: if *name* is not a known architecture.
    """
    arch_map = MMPROJ_MODEL_MAP if mmproj else TEXT_MODEL_MAP
    submodule = arch_map.get(name)
    if submodule is None:
        raise NotImplementedError(f"Architecture {name!r} not supported!")
    # Importing the submodule runs its @ModelBase.register() decorators.
    __import__(f"conversion.{submodule}")
    kind = ModelType.MMPROJ if mmproj else ModelType.TEXT
    return ModelBase._model_classes[kind][name]
|
|
|
|
|
|
def print_registered_models() -> None:
    """Log every known architecture name, grouped into TEXT and MMPROJ sections.

    NOTE: emitted at error level so the listing reaches stderr even under
    quiet logging configurations.
    """
    load_all_models()
    sections = (
        ("TEXT models:", TEXT_MODEL_MAP),
        ("MMPROJ models:", MMPROJ_MODEL_MAP),
    )
    for heading, arch_map in sections:
        logger.error(heading)
        for name in sorted(arch_map):
            logger.error(f" - {name}")