Mirror of https://github.com/LostRuins/koboldcpp.git, synced 2025-09-11 17:44:38 +00:00

Merge commit 'f4586ee598' into concedo_experimental

# Conflicts:
#	README.md
#	docs/multimodal/minicpmo2.6.md
#	docs/multimodal/minicpmv2.6.md
#	ggml/src/ggml-cann/aclnn_ops.cpp
#	ggml/src/ggml-cann/ggml-cann.cpp
#	ggml/src/ggml-cpu/kleidiai/kleidiai.cpp
#	ggml/src/ggml-cuda/CMakeLists.txt
#	ggml/src/ggml-opencl/ggml-opencl.cpp
#	ggml/src/ggml-opencl/kernels/add.cl
#	ggml/src/ggml-sycl/ggml-sycl.cpp
#	tools/perplexity/perplexity.cpp
#	tools/server/README.md

Commit d5876024ec
17 changed files with 675 additions and 340 deletions
@@ -2951,11 +2951,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             "- deepseek: puts thoughts in `message.reasoning_content` (except in streaming mode, which behaves as `none`)\n"
             "(default: auto)",
             [](common_params & params, const std::string & value) {
-                /**/ if (value == "deepseek") { params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; }
-                else if (value == "deepseek-legacy") { params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY; }
-                else if (value == "none") { params.reasoning_format = COMMON_REASONING_FORMAT_NONE; }
-                else if (value == "auto") { params.reasoning_format = COMMON_REASONING_FORMAT_AUTO; }
-                else { throw std::invalid_argument("invalid value"); }
+                params.reasoning_format = common_reasoning_format_from_name(value);
             }
         ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN}).set_env("LLAMA_ARG_THINK"));
     add_opt(common_arg(
@@ -552,6 +552,17 @@ common_chat_templates_ptr common_chat_templates_init(
             default_template_src = CHATML_TEMPLATE_SRC;
         }
     }
+
+    // TODO @ngxson : this is a temporary hack to prevent chat template from throwing an error
+    // Ref: https://github.com/ggml-org/llama.cpp/pull/15230#issuecomment-3173959633
+    if (default_template_src.find("<|channel|>") != std::string::npos
+            // search for the error message and patch it
+            && default_template_src.find("in message.content or") != std::string::npos) {
+        string_replace_all(default_template_src,
+            "{%- if \"<|channel|>analysis<|message|>\" in message.content or \"<|channel|>final<|message|>\" in message.content %}",
+            "{%- if false %}");
+    }
+
     std::string token_bos = bos_token_override;
     std::string token_eos = eos_token_override;
     bool add_bos = false;
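The hunk above neutralizes the failing condition in the bundled gpt-oss chat template by rewriting it to a constant-false Jinja test before the template is used. Below is a self-contained sketch of that kind of find-and-replace patching; the replace_all helper and the miniature template string are invented for the illustration and are not the repository's string_replace_all or its real template.

#include <iostream>
#include <string>

// Standalone illustration only: replace every occurrence of `search` in `s` with
// `replace`, similar in spirit to the string_replace_all() call in the hunk above.
static void replace_all(std::string & s, const std::string & search, const std::string & replace) {
    for (size_t pos = 0; (pos = s.find(search, pos)) != std::string::npos; pos += replace.size()) {
        s.replace(pos, search.size(), replace);
    }
}

int main() {
    // Made-up miniature template standing in for the bundled gpt-oss chat template.
    std::string tmpl = "{%- if \"<|channel|>analysis<|message|>\" in message.content %}...{%- endif %}";
    replace_all(tmpl, "{%- if \"<|channel|>analysis<|message|>\" in message.content %}", "{%- if false %}");
    std::cout << tmpl << "\n"; // the problematic condition is now a constant false
    return 0;
}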
@@ -625,6 +636,19 @@ const char * common_reasoning_format_name(common_reasoning_format format) {
     }
 }

+common_reasoning_format common_reasoning_format_from_name(const std::string & format) {
+    if (format == "none") {
+        return COMMON_REASONING_FORMAT_NONE;
+    } else if (format == "auto") {
+        return COMMON_REASONING_FORMAT_AUTO;
+    } else if (format == "deepseek") {
+        return COMMON_REASONING_FORMAT_DEEPSEEK;
+    } else if (format == "deepseek-legacy") {
+        return COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY;
+    }
+    throw std::runtime_error("Unknown reasoning format: " + format);
+}
+
 static std::string wrap_code_as_arguments(common_chat_msg_parser & builder, const std::string & code) {
     std::string arguments;
     if (builder.is_partial()) {
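The new common_reasoning_format_from_name() centralizes the name-to-enum mapping that the argument parser previously inlined (see the first hunk above). A self-contained sketch of the same mapping follows; the enum and function are re-declared locally for the example and are not the actual llama.cpp declarations.

#include <iostream>
#include <stdexcept>
#include <string>

// Local stand-ins for the enum values used in the hunk above; illustration only.
enum common_reasoning_format_sketch {
    REASONING_FORMAT_NONE,
    REASONING_FORMAT_AUTO,
    REASONING_FORMAT_DEEPSEEK,
    REASONING_FORMAT_DEEPSEEK_LEGACY,
};

// Mirrors the mapping the new helper centralizes: unknown names raise
// instead of silently falling through to a default.
static common_reasoning_format_sketch reasoning_format_from_name(const std::string & name) {
    if (name == "none")            return REASONING_FORMAT_NONE;
    if (name == "auto")            return REASONING_FORMAT_AUTO;
    if (name == "deepseek")        return REASONING_FORMAT_DEEPSEEK;
    if (name == "deepseek-legacy") return REASONING_FORMAT_DEEPSEEK_LEGACY;
    throw std::runtime_error("Unknown reasoning format: " + name);
}

int main() {
    std::cout << reasoning_format_from_name("deepseek") << "\n"; // prints 2
    try {
        reasoning_format_from_name("bogus");
    } catch (const std::exception & e) {
        std::cerr << e.what() << "\n"; // Unknown reasoning format: bogus
    }
    return 0;
}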
@@ -191,6 +191,7 @@ std::string common_chat_format_example(

 const char* common_chat_format_name(common_chat_format format);
 const char* common_reasoning_format_name(common_reasoning_format format);
+common_reasoning_format common_reasoning_format_from_name(const std::string & format);
 common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax);

 common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice);
@@ -28,6 +28,14 @@ if TYPE_CHECKING:
 if 'NO_LOCAL_GGUF' not in os.environ:
     sys.path.insert(1, str(Path(__file__).parent / 'gguf-py'))
 import gguf
+from gguf.vocab import MistralTokenizerType, MistralVocab
+from mistral_common.tokens.tokenizers.base import TokenizerVersion
+from mistral_common.tokens.tokenizers.multimodal import DATASET_MEAN, DATASET_STD
+from mistral_common.tokens.tokenizers.tekken import Tekkenizer
+from mistral_common.tokens.tokenizers.sentencepiece import (
+    SentencePieceTokenizer,
+)
+
 logger = logging.getLogger("hf-to-gguf")

@@ -81,6 +89,8 @@ class ModelBase:
     block_count: int
     tensor_map: gguf.TensorNameMap

+    is_mistral_format: bool = False
+
     def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, *, is_big_endian: bool = False,
                  use_temp_file: bool = False, eager: bool = False,
                  metadata_override: Path | None = None, model_name: str | None = None,
@@ -106,16 +116,17 @@ class ModelBase:
                 logger.info(f"Using remote model with HuggingFace id: {remote_hf_model_id}")
                 remote_tensors = gguf.utility.SafetensorRemote.get_list_tensors_hf_model(remote_hf_model_id)
                 self.tensor_names = set(name for name in remote_tensors.keys())
-                for name, remote_tensor in gguf.utility.SafetensorRemote.get_list_tensors_hf_model(remote_hf_model_id).items():
+                for name, remote_tensor in remote_tensors.items():
                     yield (name, LazyTorchTensor.from_remote_tensor(remote_tensor))

             self.get_tensors = get_remote_tensors
         else:
-            self.part_names = ModelBase.get_model_part_names(self.dir_model, "model", ".safetensors")
+            prefix = "model" if not self.is_mistral_format else "consolidated"
+            self.part_names = ModelBase.get_model_part_names(self.dir_model, prefix, ".safetensors")
             self.is_safetensors = len(self.part_names) > 0
             if not self.is_safetensors:
                 self.part_names = ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin")
-            self.hparams = ModelBase.load_hparams(self.dir_model) if hparams is None else hparams
+            self.hparams = ModelBase.load_hparams(self.dir_model, self.is_mistral_format) if hparams is None else hparams
             self.tensor_names = None
             self.metadata_override = metadata_override
             self.model_name = model_name
@@ -153,6 +164,7 @@ class ModelBase:
     def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
         tensor_names_from_parts: set[str] = set()

+        if not self.is_mistral_format:
             index_name = "model.safetensors" if self.is_safetensors else "pytorch_model.bin"
             index_name += ".index.json"
             index_file = self.dir_model / index_name
@@ -169,6 +181,9 @@ class ModelBase:
             else:
                 self.tensor_names = tensor_names_from_parts
                 weight_map = {}
+        else:
+            self.tensor_names = tensor_names_from_parts
+            weight_map = {}

         for part_name in self.part_names:
             logger.info(f"gguf: loading model part '{part_name}'")
@@ -426,7 +441,12 @@ class ModelBase:
         return part_names

     @staticmethod
-    def load_hparams(dir_model: Path):
+    def load_hparams(dir_model: Path, is_mistral_format: bool):
+        if is_mistral_format:
+            with open(dir_model / "params.json", "r", encoding="utf-8") as f:
+                config = json.load(f)
+            return config
+
         try:
             # for security reason, we don't allow loading remote code by default
             # if a model need remote code, we will fallback to config.json
@@ -476,7 +496,10 @@ class TextModel(ModelBase):

     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.hf_arch = get_model_architecture(self.hparams, self.model_type)
+        if not self.is_mistral_format:
+            self.hf_arch = get_model_architecture(self.hparams, self.model_type)
+        else:
+            self.hf_arch = ""

         if "text_config" in self.hparams:
             # move the text_config to the root level
@@ -542,14 +565,14 @@ class TextModel(ModelBase):
             self.gguf_writer.add_head_count(n_head)
             logger.info(f"gguf: head count = {n_head}")

-        if (n_head_kv := self.hparams.get("num_key_value_heads")) is not None:
+        if (n_head_kv := self.find_hparam(["num_key_value_heads", "n_kv_heads"], optional=True)) is not None:
             self.gguf_writer.add_head_count_kv(n_head_kv)
             logger.info(f"gguf: key-value head count = {n_head_kv}")

         if (rope_theta := self.hparams.get("rope_theta")) is not None:
             self.gguf_writer.add_rope_freq_base(rope_theta)
             logger.info(f"gguf: rope theta = {rope_theta}")
-        if (f_rms_eps := self.hparams.get("rms_norm_eps")) is not None:
+        if (f_rms_eps := self.find_hparam(["rms_norm_eps", "norm_eps"], optional=True)) is not None:
             self.gguf_writer.add_layer_norm_rms_eps(f_rms_eps)
             logger.info(f"gguf: rms norm epsilon = {f_rms_eps}")
         if (f_norm_eps := self.find_hparam(["layer_norm_eps", "layer_norm_epsilon", "norm_epsilon"], optional=True)) is not None:
@@ -1210,12 +1233,19 @@ class MmprojModel(ModelBase):
             raise TypeError("MmprojModel must be subclassed with model_arch = gguf.MODEL_ARCH.MMPROJ")

         # get n_embd of the text model
+        if not self.is_mistral_format:
             if "text_config" not in self.hparams:
                 self.hparams["text_config"] = {}
             if "audio_config" not in self.hparams:
                 self.hparams["audio_config"] = {}
             text_config = {**self.hparams, **self.hparams["text_config"]}
             self.n_embd_text = text_config.get("hidden_size", text_config.get("n_embd", 0))
+        else:
+            text_config = {
+                k: v for k, v in self.hparams.items() if k not in ["vision_encoder", "audio_encoder"]
+            }
+            self.n_embd_text = text_config.get("hidden_dim", 0)
+
         assert self.n_embd_text > 0, "n_embd not found in hparams"

         # move vision config to the top level, while preserving the original hparams in global_config
@@ -1236,11 +1266,13 @@ class MmprojModel(ModelBase):
         self.tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.MMPROJ, self.block_count)

         # load preprocessor config
+        if not self.is_mistral_format:
             with open(self.dir_model / "preprocessor_config.json", "r", encoding="utf-8") as f:
                 self.preprocessor_config = json.load(f)

     def get_vision_config(self) -> dict[str, Any] | None:
-        return self.global_config.get("vision_config")
+        config_name = "vision_config" if not self.is_mistral_format else "vision_encoder"
+        return self.global_config.get(config_name)

     def get_audio_config(self) -> dict[str, Any] | None:
         return self.global_config.get("audio_config")
@@ -1264,8 +1296,11 @@ class MmprojModel(ModelBase):
         self.gguf_writer.add_vision_head_count(self.find_vparam(["num_attention_heads"]))

         # preprocessor config
-        self.gguf_writer.add_vision_image_mean(self.preprocessor_config["image_mean"])
-        self.gguf_writer.add_vision_image_std(self.preprocessor_config["image_std"])
+        image_mean = DATASET_MEAN if self.is_mistral_format else self.preprocessor_config["image_mean"]
+        image_std = DATASET_STD if self.is_mistral_format else self.preprocessor_config["image_std"]
+
+        self.gguf_writer.add_vision_image_mean(image_mean)
+        self.gguf_writer.add_vision_image_std(image_std)

         if self.has_audio_encoder:
             self.gguf_writer.add_clip_has_audio_encoder(True)
@@ -1924,11 +1959,63 @@ class LlamaModel(TextModel):
         if self.hf_arch == "VLlama3ForCausalLM":
             self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 32)

+    def _set_vocab_mistral(self):
+        vocab = MistralVocab(self.dir_model)
+        logger.info(
+            f"Converting tokenizer {vocab.tokenizer_type} of size {vocab.vocab_size}."
+        )
+
+        self.gguf_writer.add_tokenizer_model(vocab.gguf_tokenizer_model)
+
+        tokens = []
+        scores = []
+        toktypes = []
+
+        for text, score, toktype in vocab.all_tokens():
+            tokens.append(text)
+            scores.append(score)
+            toktypes.append(toktype)
+
+        assert len(tokens) == vocab.vocab_size, (
+            f"token count ({len(tokens)}) != vocab size ({vocab.vocab_size})"
+        )
+
+        if vocab.tokenizer_type == MistralTokenizerType.tekken:
+            self.gguf_writer.add_tokenizer_pre("tekken")
+            self.gguf_writer.add_token_merges(
+                vocab.extract_vocab_merges_from_model()
+            )
+
+        logger.info(
+            f"Setting bos, eos, unk and pad token IDs to {vocab.bos_id}, {vocab.eos_id}, {vocab.unk_id}, {vocab.pad_id}."
+        )
+
+        self.gguf_writer.add_bos_token_id(vocab.bos_id)
+        self.gguf_writer.add_eos_token_id(vocab.eos_id)
+        self.gguf_writer.add_unk_token_id(vocab.unk_id)
+        self.gguf_writer.add_pad_token_id(vocab.pad_id)
+
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_scores(scores)
+        self.gguf_writer.add_token_types(toktypes)
+        self.gguf_writer.add_vocab_size(vocab.vocab_size)
+
+        self.gguf_writer.add_add_bos_token(True)
+        self.gguf_writer.add_add_eos_token(False)
+
+        template_dir = Path(__file__).parent / "models/templates/"
+
+        template = MistralModel.get_community_chat_template(vocab, template_dir)
+        self.gguf_writer.add_chat_template(template)
+
     def set_vocab(self):
+        if self.is_mistral_format:
+            return self._set_vocab_mistral()
+
         path_tekken_json = self.dir_model / "tekken.json"
         path_tokenizer_json = self.dir_model / "tokenizer.json"
         if path_tekken_json.is_file() and not path_tokenizer_json.is_file():
-            return self.set_vocab_tekken()
+            self._set_vocab_mistral()

         try:
             self._set_vocab_sentencepiece()
@@ -1962,55 +2049,11 @@ class LlamaModel(TextModel):
         if self.hparams.get("vocab_size", 32000) == 49152:
             self.gguf_writer.add_add_bos_token(False)

-    def set_vocab_tekken(self):
-        vocab = gguf.vocab.MistralVocab(self.dir_model)
-        self.gguf_writer.add_tokenizer_model(vocab.gguf_tokenizer_model)
-
-        tokens = []
-        scores = []
-        toktypes = []
-
-        for text, score, toktype in vocab.all_tokens():
-            tokens.append(text)
-            scores.append(score)
-            toktypes.append(toktype)
-
-        assert len(tokens) == vocab.vocab_size, (
-            f"token count ({len(tokens)}) != vocab size ({vocab.vocab_size})"
-        )
-
-        if vocab.tokenizer_type == gguf.vocab.MistralTokenizerType.tekken:
-            self.gguf_writer.add_tokenizer_pre("tekken")
-            self.gguf_writer.add_token_merges(
-                vocab.extract_vocab_merges_from_model()
-            )
-
-        logger.info(
-            f"Setting bos, eos, unk and pad token IDs to {vocab.bos_id}, {vocab.eos_id}, {vocab.unk_id}, {vocab.pad_id}."
-        )
-
-        self.gguf_writer.add_bos_token_id(vocab.bos_id)
-        self.gguf_writer.add_eos_token_id(vocab.eos_id)
-        self.gguf_writer.add_unk_token_id(vocab.unk_id)
-        self.gguf_writer.add_pad_token_id(vocab.pad_id)
-
-        self.gguf_writer.add_token_list(tokens)
-        self.gguf_writer.add_token_scores(scores)
-        self.gguf_writer.add_token_types(toktypes)
-        self.gguf_writer.add_vocab_size(vocab.vocab_size)
-
-        self.gguf_writer.add_add_bos_token(True)
-        self.gguf_writer.add_add_eos_token(False)
-
-        script_dir = Path(__file__).parent
-        template_path = script_dir / "models/templates/unsloth-mistral-Devstral-Small-2507.jinja"
-        with open(template_path, "r", encoding="utf-8") as f:
-            template = f.read()
-        self.gguf_writer.add_chat_template(template)
-
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
         hparams = self.hparams

+        if not self.is_mistral_format:
             self.gguf_writer.add_vocab_size(hparams["vocab_size"])

         if (rope_dim := hparams.get("head_dim")) is None:
@@ -2033,13 +2076,25 @@ class LlamaModel(TextModel):
     _experts: list[dict[str, Tensor]] | None = None

     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
-        n_head = self.hparams["num_attention_heads"]
-        n_kv_head = self.hparams.get("num_key_value_heads")
+        n_head = self.find_hparam(["n_heads", "num_attention_heads"])
+        n_kv_head = self.find_hparam(["n_kv_heads", "num_key_value_heads"])
+
+        vision_prefixes = [
+            "vision_encoder.",
+            "vision_language_adapter.",
+            "patch_merger.",
+            "pre_mm_projector_norm",
+        ]
+
         is_multimodal_tensor = "vision_tower" in name \
             or "vision_model" in name \
             or "audio_tower" in name \
             or "model.connector" in name \
-            or "multi_modal_projector" in name
+            or "multi_modal_projector" in name \
+            or any(
+                name.startswith(prefix)
+                for prefix in vision_prefixes
+            )

         if is_multimodal_tensor:
             return [] # skip vision tensors
@@ -2155,13 +2210,18 @@ class LlavaVisionModel(MmprojModel):

     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        if self.hparams["model_type"] == "pixtral":
+        if self.hparams.get("model_type") == "pixtral":
             # layer_norm_eps is not in config.json, it is hard-coded in modeling_pixtral.py
             self.hparams["layer_norm_eps"] = self.hparams.get("layer_norm_eps", 1e-5)
             self.img_break_tok_id = self.get_token_id("[IMG_BREAK]")
-            logger.info(f"Image break token id: {self.img_break_tok_id}")
+        elif self.is_mistral_format:
+            # hparams is already vision config here so norm_eps is only defined in global_config.
+            self.hparams["norm_eps"] = self.global_config.get("norm_eps", None)
+            assert self.hparams["norm_eps"] is not None, "norm_eps not found in params.json"
+            self.img_break_tok_id = self.find_vparam(["image_break_token_id"])
         else:
             raise ValueError(f"Unsupported model type: {self.hparams['model_type']}")
+        logger.info(f"Image break token id: {self.img_break_tok_id}")

     def get_token_id(self, token: str) -> int:
         tokenizer_config_file = self.dir_model / 'tokenizer_config.json'
|
||||||
def set_gguf_parameters(self):
|
def set_gguf_parameters(self):
|
||||||
super().set_gguf_parameters()
|
super().set_gguf_parameters()
|
||||||
hparams = self.hparams
|
hparams = self.hparams
|
||||||
if hparams["model_type"] == "pixtral":
|
if hparams.get("model_type") == "pixtral":
|
||||||
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.PIXTRAL)
|
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.PIXTRAL)
|
||||||
self.gguf_writer.add_vision_attention_layernorm_eps(hparams["layer_norm_eps"])
|
self.gguf_writer.add_vision_attention_layernorm_eps(hparams["layer_norm_eps"])
|
||||||
|
|
||||||
|
@@ -2193,18 +2253,30 @@ class LlavaVisionModel(MmprojModel):

     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         del bid # unused
-        n_head = self.hparams["num_attention_heads"]
+        n_head = (
+            self.hparams["num_attention_heads"] if not self.is_mistral_format else self.find_vparam(["num_attention_heads"])
+        )
         n_kv_head = n_head

-        if name.startswith("multi_modal_projector.") or name.startswith("vision_tower."):
+        valid_prefixes = (
+            "multi_modal_projector.",
+            "vision_tower.",
+            "vision_encoder.",
+            "vision_language_adapter.",
+            "patch_merger.",
+            "pre_mm_projector_norm",
+        )
+
+        if any(name.startswith(prefix) for prefix in valid_prefixes):
             # process vision tensors
-            if name.endswith(("q_proj.weight", "q_proj.bias")):
+            if name.endswith(("q_proj.weight", "q_proj.bias")) and not self.is_mistral_format:
                 data_torch = LlamaModel.permute(data_torch, n_head, n_head)
-            if name.endswith(("k_proj.weight", "k_proj.bias")):
+            if name.endswith(("k_proj.weight", "k_proj.bias")) and not self.is_mistral_format:
                 data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head)
             return [(self.map_tensor_name(name), data_torch)]

-        if self.img_break_tok_id > 0 and "embed_tokens.weight" in name:
+        embed_key = "embed_tokens.weight" if not self.is_mistral_format else "tok_embeddings.weight"
+        if self.img_break_tok_id > 0 and embed_key in name:
             logger.info(f"Extracting [IMG_BREAK] token embedding from {name}")
             # for pixtral model, we need to extract the [IMG_BREAK] token embedding
             img_break_embd = data_torch[self.img_break_tok_id]
@@ -3526,7 +3598,7 @@ class Qwen3MoeModel(Qwen2MoeModel):

     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        hparams = ModelBase.load_hparams(self.dir_model)
+        hparams = ModelBase.load_hparams(self.dir_model, False)
         self.origin_hf_arch = hparams.get('architectures', [None])[0]

     def set_vocab(self):
@@ -4683,7 +4755,7 @@ class NomicBertModel(BertModel):
     def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, **kwargs: Any):
         hparams = kwargs.pop("hparams", None)
         if hparams is None:
-            hparams = ModelBase.load_hparams(dir_model)
+            hparams = ModelBase.load_hparams(dir_model, False)

         self.is_moe = bool(hparams.get("moe_every_n_layers"))
         self.model_arch = gguf.MODEL_ARCH.NOMIC_BERT_MOE if self.is_moe else gguf.MODEL_ARCH.NOMIC_BERT
@@ -8304,6 +8376,77 @@ class SmallThinkerModel(TextModel):
         if len(experts) > 0:
             raise ValueError(f"Unprocessed experts: {experts}")

+
+class MistralModel(LlamaModel):
+    model_arch = gguf.MODEL_ARCH.LLAMA
+    model_name = "Mistral"
+    hf_arch = ""
+    is_mistral_format = True
+    undo_permute = False
+
+    @staticmethod
+    def get_community_chat_template(vocab: MistralVocab, templates_dir: Path):
+        assert TokenizerVersion is not None, "mistral_common is not installed"
+        assert isinstance(vocab.tokenizer, (Tekkenizer, SentencePieceTokenizer)), (
+            f"Expected Tekkenizer or SentencePieceTokenizer, got {type(vocab.tokenizer)}"
+        )
+
+        if vocab.tokenizer.version == TokenizerVersion.v1:
+            return "mistral-v1"
+        elif vocab.tokenizer.version == TokenizerVersion.v3 and vocab.tokenizer_type == MistralTokenizerType.spm:
+            return "mistral-v3"
+        elif vocab.tokenizer.version == TokenizerVersion.v3 and vocab.tokenizer_type == MistralTokenizerType.tekken:
+            return "mistral-v3-tekken"
+        elif vocab.tokenizer.version == TokenizerVersion.v7 and vocab.tokenizer_type == MistralTokenizerType.spm:
+            return "mistral-v7"
+        elif vocab.tokenizer.version == TokenizerVersion.v7 and vocab.tokenizer_type == MistralTokenizerType.tekken:
+            return "mistral-v7-tekken"
+        elif vocab.tokenizer.version == TokenizerVersion.v11:
+            template_file = "Mistral-Small-3.2-24B-Instruct-2506.jinja"
+        elif vocab.tokenizer.version == TokenizerVersion.v13:
+            template_file = "unsloth-mistral-Devstral-Small-2507.jinja"
+        else:
+            raise ValueError(f"Unknown tokenizer type: {vocab.tokenizer_type} and version {vocab.tokenizer.version}")
+
+        template_path = templates_dir / template_file
+        if not template_path.exists():
+            raise FileNotFoundError(f"Template file not found: {template_path}")
+
+        with open(template_path, "r", encoding="utf-8") as f:
+            template = f.read()
+
+        return template
+
+
+class PixtralModel(LlavaVisionModel):
+    model_name = "Pixtral"
+    hf_arch = ""
+    is_mistral_format = True
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.PIXTRAL)
+
+        self.gguf_writer.add_vision_attention_layernorm_eps(
+            self.find_hparam(["norm_eps"])
+        )
+        self.gguf_writer.add_rope_freq_base(self.find_vparam(["rope_theta"]))
+
+        self.gguf_writer.add_vision_use_silu(True)
+
+        # spatial_merge_size
+        if self.find_vparam(["mm_projector_id"]) == "patch_merge":
+            self.gguf_writer.add_vision_spatial_merge_size(
+                self.find_vparam(["spatial_merge_size"])
+            )
+
+    def map_tensor_name(self, name: str, try_suffixes: Sequence[str] = (".weight", ".bias")) -> str:
+        if name == "vision_language_adapter.w_in.weight":
+            return "mm.1.weight"
+        elif name == "vision_language_adapter.w_out.weight":
+            return "mm.2.weight"
+        return super().map_tensor_name(name, try_suffixes)
+
+
 ###### CONVERSION LOGIC ######

@@ -8454,6 +8597,10 @@ def parse_args() -> argparse.Namespace:
         "--mmproj", action="store_true",
         help="(Experimental) Export multimodal projector (mmproj) for vision models. This will only work on some vision models. A prefix 'mmproj-' will be added to the output file name.",
     )
+    parser.add_argument(
+        "--mistral-format", action="store_true",
+        help="Whether the model is stored following the Mistral format.",
+    )

     args = parser.parse_args()
     if not args.print_supported_models and args.model is None:
@@ -8559,10 +8706,13 @@ def main() -> None:
         if "mmproj" not in fname_out.name:
             fname_out = ModelBase.add_prefix_to_filename(fname_out, "mmproj-")

+    is_mistral_format = args.mistral_format
+
     with torch.inference_mode():
         output_type = ftype_map[args.outtype]
         model_type = ModelType.MMPROJ if args.mmproj else ModelType.TEXT
-        hparams = ModelBase.load_hparams(dir_model)
+        hparams = ModelBase.load_hparams(dir_model, is_mistral_format)
+        if not is_mistral_format:
             model_architecture = get_model_architecture(hparams, model_type)
             logger.info(f"Model architecture: {model_architecture}")
             try:
@@ -8570,6 +8720,11 @@ def main() -> None:
             except NotImplementedError:
                 logger.error(f"Model {model_architecture} is not supported")
                 sys.exit(1)
+        elif args.mmproj:
+            assert hparams.get("vision_encoder") is not None, "This model does not support multimodal"
+            model_class = PixtralModel
+        else:
+            model_class = MistralModel

         model_instance = model_class(dir_model, output_type, fname_out,
                                      is_big_endian=args.bigendian, use_temp_file=args.use_temp_file,
@@ -8578,7 +8733,8 @@ def main() -> None:
                                      split_max_tensors=args.split_max_tensors,
                                      split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run,
                                      small_first_shard=args.no_tensor_first_split,
-                                     remote_hf_model_id=hf_repo_id)
+                                     remote_hf_model_id=hf_repo_id,
+                                     )

         if args.vocab_only:
             logger.info("Exporting model vocab...")
@@ -340,7 +340,7 @@ if __name__ == '__main__':
         sys.exit(1)
     else:
         logger.info(f"Loading base model: {dir_base_model.name}")
-        hparams = ModelBase.load_hparams(dir_base_model)
+        hparams = ModelBase.load_hparams(dir_base_model, False)

     with torch.inference_mode():
         try:
@@ -316,11 +316,11 @@ static bool turing_mma_available(const int cc) {
 }

 static bool ampere_mma_available(const int cc) {
-    return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE;
+    return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE;
 }

 static bool cp_async_available(const int cc) {
-    return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE;
+    return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE;
 }

 static constexpr __device__ int ggml_cuda_get_physical_warp_size() {
@@ -1,87 +1,117 @@
+#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
+#define USE_CUB
+#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
+
+#ifdef USE_CUB
+#include <cub/cub.cuh>
+using namespace cub;
+#endif // USE_CUB
+
 #include "ssm-scan.cuh"

-template <size_t splitD, size_t N>
-__global__ void __launch_bounds__(splitD, 2)
-    ssm_scan_f32(const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2,
-                 const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5,
+// We would like to keep pragma unroll for cases where L_template is not 0,
+// so we suppress the clang transformation warning.
+#ifdef __clang__
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wpass-failed"
+#endif // __clang__
+template <size_t splitD, size_t N, size_t L_template>
+__global__ void __launch_bounds__(splitD, 1)
+    ssm_scan_f32(const float *__restrict__ src0, const float *__restrict__ src1, const float *__restrict__ src2,
+                 const float *__restrict__ src3, const float *__restrict__ src4, const float *__restrict__ src5,
                  const int32_t * __restrict__ src6, float * __restrict__ dst,
                  const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3,
                  const int src2_nb1, const int src2_nb2, const int src3_nb1,
                  const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3,
-                 const int64_t s_off, const int64_t d_inner, const int64_t L) {
+                 const int64_t s_off, const int64_t d_inner, const int64_t L_param)
+{
+    const size_t L = L_template == 0 ? L_param : L_template;
+    const float *s0_block = (const float *)((const char *)src0 + src6[blockIdx.x] * src0_nb3 + blockIdx.y * splitD * src0_nb2);
+    const float *x_block = (const float *)((const char *)src1 + (blockIdx.x * src1_nb3) + blockIdx.y * splitD * sizeof(float));
+    const float *dt_block = (const float *)((const char *)src2 + (blockIdx.x * src2_nb2) + blockIdx.y * splitD * sizeof(float));
+    const float *A_block = (const float *)((const char *)src3 + blockIdx.y * splitD * src3_nb1);
+    const float *B_block = (const float *)((const char *)src4 + (blockIdx.x * src4_nb3));
+    const float *C_block = (const float *)((const char *)src5 + (blockIdx.x * src5_nb3));
+    float *y_block = (float *)((char *)dst + (blockIdx.x * d_inner * L * sizeof(float)) + blockIdx.y * splitD * sizeof(float));
+    float *s_block = (float *)((char *)dst + s_off + blockIdx.x * src0_nb3 + blockIdx.y * splitD * src0_nb2);

-    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
-    const int bidx = blockIdx.x;  // split along B (sequences)
-    const int bidy = blockIdx.y;  // split along D (d_inner)
-    const int tid  = threadIdx.x;
-    const int wid  = tid / 32;
-    const int wtid = tid % 32;
-
-    extern __shared__ float smem[];
-    const int stride_sA  = N + 1;
-    const int stride_ss0 = N + 1;
-    float * smem_A  = smem;
-    float * smem_s0 = smem_A + splitD * stride_sA;
-
-    const float * s0_block = (const float *) ((const char *) src0 + src6[bidx] * src0_nb3 + bidy * splitD * src0_nb2);
-    const float * x_block  = (const float *) ((const char *) src1 + (bidx * src1_nb3) + bidy * splitD * sizeof(float));
-    const float * dt_block = (const float *) ((const char *) src2 + (bidx * src2_nb2) + bidy * splitD * sizeof(float));
-    const float * A_block  = (const float *) ((const char *) src3 + bidy * splitD * src3_nb1);
-    const float * B_block  = (const float *) ((const char *) src4 + (bidx * src4_nb3));
-    const float * C_block  = (const float *) ((const char *) src5 + (bidx * src5_nb3));
-    float *       y_block  = (float *) ((char *) dst + (bidx * d_inner * L * sizeof(float)) + bidy * splitD * sizeof(float));
-    float *       s_block  = (float *) ((char *) dst + s_off + bidx * src0_nb3 + bidy * splitD * src0_nb2);
-
-    const int stride_s0 = src0_nb2 / sizeof(float);
     const int stride_x  = src1_nb2 / sizeof(float);
     const int stride_dt = src2_nb1 / sizeof(float);
-    const int stride_A  = src3_nb1 / sizeof(float);
     const int stride_B  = src4_nb2 / sizeof(float);
     const int stride_C  = src5_nb2 / sizeof(float);
-    const int stride_s  = stride_s0;
     const int stride_y  = d_inner;

-    // can N not be 16? for example 32?
-    if (N == 16) {
-#pragma unroll
-        for (size_t i = 0; i < splitD / 4; i += 2) {
-            float value = A_block[(wid * warp_size + i) * stride_A + wtid];
-            // todo: bank conflict
-            // I am always confused with how to use the swizzling method to solve
-            // bank conflit. Hoping somebody can tell me.
-            smem_A[(wid * warp_size + i) * stride_sA + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
-        }
-#pragma unroll
-        for (size_t i = 0; i < splitD / 4; i += 2) {
-            float value = s0_block[(wid * warp_size + i) * stride_s0 + wtid];
-            smem_s0[(wid * warp_size + i) * stride_ss0 + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
-        }
-    }
-
+    float regA[N];
+    float regs0[N];
+
+    __shared__ float smemB[N];
+    __shared__ float smemC[N];
+
+#ifdef USE_CUB
+    using BlockLoad  = cub::BlockLoad<float, splitD, N, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
+    using BlockStore = cub::BlockStore<float, splitD, N, cub::BLOCK_STORE_WARP_TRANSPOSE>;
+
+    union CubTempStorage {
+        typename BlockLoad::TempStorage load_temp;
+        typename BlockStore::TempStorage store_temp;
+    };
+    __shared__ CubTempStorage cub_temp_storage;
+
+    BlockLoad(cub_temp_storage.load_temp).Load(A_block, regA);
+    BlockLoad(cub_temp_storage.load_temp).Load(s0_block, regs0);
+#else
+    const int stride_s0 = src0_nb2 / sizeof(float);
+    const int stride_A  = src3_nb1 / sizeof(float);
+#pragma unroll
+    for (size_t n = 0; n < N; ++n)
+    {
+        regA[n]  = A_block[threadIdx.x * stride_A + n];
+        regs0[n] = s0_block[threadIdx.x * stride_s0 + n];
+    }
+#endif
+
+#pragma unroll
+    for (size_t i = 0; i < L; i++)
+    {
+        if (threadIdx.x < N)
+        {
+            smemB[threadIdx.x] = B_block[i * stride_B + threadIdx.x];
+            smemC[threadIdx.x] = C_block[i * stride_C + threadIdx.x];
+        }
         __syncthreads();

-    for (int64_t i = 0; i < L; i++) {
-        float dt_soft_plus = dt_block[i * stride_dt + tid];
-        if (dt_soft_plus <= 20.0f) {
-            dt_soft_plus = log1pf(exp(dt_soft_plus));
+        float dt_soft_plus = dt_block[i * stride_dt + threadIdx.x];
+        if (dt_soft_plus <= 20.0f)
+        {
+            dt_soft_plus = log1pf(expf(dt_soft_plus));
         }
-        float x_dt = x_block[i * stride_x + tid] * dt_soft_plus;
+        float x_dt = x_block[i * stride_x + threadIdx.x] * dt_soft_plus;

         float sumf = 0.0f;
 #pragma unroll
-        for (size_t j = 0; j < N; j++) {
-            float state = (smem_s0[tid * stride_ss0 + j] * expf(dt_soft_plus * smem_A[tid * stride_sA + j])) +
-                          (B_block[i * stride_B + j] * x_dt);
-            sumf += state * C_block[i * stride_C + j];
-            if (i == L - 1) {
-                s_block[tid * stride_s + j] = state;
-            } else {
-                smem_s0[tid * stride_ss0 + j] = state;
-            }
+        for (size_t n = 0; n < N; n++)
+        {
+            float state = regs0[n] * expf(dt_soft_plus * regA[n]) + smemB[n] * x_dt;
+            sumf += state * smemC[n];
+            regs0[n] = state;
         }
-        __syncthreads();
-        y_block[i * stride_y + tid] = sumf;
+        y_block[i * stride_y + threadIdx.x] = sumf;
     }
+
+#ifdef USE_CUB
+    BlockStore(cub_temp_storage.store_temp).Store(s_block, regs0);
+#else
+    const int stride_s = stride_s0;
+#pragma unroll
+    for (size_t n = 0; n < N; ++n)
+    {
+        s_block[threadIdx.x * stride_s + n] = regs0[n];
+    }
+#endif
 }
+#ifdef __clang__
+#pragma clang diagnostic pop
+#endif // __clang__

 // assumes as many threads as d_state
 template <int splitH, int d_state>
@@ -201,11 +231,11 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
                               const int src5_nb3, const int64_t s_off, const int64_t d_state, const int64_t head_dim,
                               const int64_t n_head, const int64_t n_group, const int64_t n_tok, const int64_t n_seq,
                               cudaStream_t stream) {
+    const int threads = 128;
     // NOTE: if you change conditions here, be sure to update the corresponding supports_op condition!
     if (src3_nb1 == sizeof(float)) {
         // Mamba-2
         if (d_state == 128) {
-            const int threads = 128;
             GGML_ASSERT(d_state % threads == 0);
             // NOTE: can be any power of two between 4 and 64
             const int splitH = 16;
@@ -229,7 +259,6 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
             GGML_ABORT("doesn't support d_state!=(128 or 256).");
         }
     } else {
-        const int threads = 128;
         // Mamba-1
         GGML_ASSERT(n_head % threads == 0);
         GGML_ASSERT(head_dim == 1);
@@ -237,10 +266,63 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
         const dim3 blocks(n_seq, (n_head + threads - 1) / threads, 1);
         const int smem_size = (threads * (d_state + 1) * 2) * sizeof(float);
         if (d_state == 16) {
-            ssm_scan_f32<128, 16><<<blocks, threads, smem_size, stream>>>(
+            switch (n_tok)
+            {
+                case 1:
+                    ssm_scan_f32<threads, 16, 1><<<blocks, threads, smem_size, stream>>>(
                         src0, src1, src2, src3, src4, src5, src6, dst,
                         src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
                         src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
+                    break;
+                case 2:
+                    ssm_scan_f32<threads, 16, 2><<<blocks, threads, smem_size, stream>>>(
+                        src0, src1, src2, src3, src4, src5, src6, dst,
+                        src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
+                        src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
+                    break;
+                case 3:
+                    ssm_scan_f32<threads, 16, 3><<<blocks, threads, smem_size, stream>>>(
+                        src0, src1, src2, src3, src4, src5, src6, dst,
+                        src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
+                        src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
+                    break;
+                case 4:
+                    ssm_scan_f32<threads, 16, 4><<<blocks, threads, smem_size, stream>>>(
+                        src0, src1, src2, src3, src4, src5, src6, dst,
+                        src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
+                        src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
+                    break;
+                case 5:
+                    ssm_scan_f32<threads, 16, 5><<<blocks, threads, smem_size, stream>>>(
+                        src0, src1, src2, src3, src4, src5, src6, dst,
+                        src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
+                        src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
+                    break;
+                case 6:
+                    ssm_scan_f32<threads, 16, 6><<<blocks, threads, smem_size, stream>>>(
+                        src0, src1, src2, src3, src4, src5, src6, dst,
+                        src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
+                        src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
+                    break;
+                case 7:
+                    ssm_scan_f32<threads, 16, 7><<<blocks, threads, smem_size, stream>>>(
+                        src0, src1, src2, src3, src4, src5, src6, dst,
+                        src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
+                        src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
+                    break;
+                case 8:
+                    ssm_scan_f32<threads, 16, 8><<<blocks, threads, smem_size, stream>>>(
+                        src0, src1, src2, src3, src4, src5, src6, dst,
+                        src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
+                        src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
+                    break;
+                default:
+                    ssm_scan_f32<threads, 16, 0><<<blocks, threads, smem_size, stream>>>(
+                        src0, src1, src2, src3, src4, src5, src6, dst,
+                        src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
+                        src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
+                    break;
+            }
         } else {
             GGML_ABORT("doesn't support d_state!=16.");
         }
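The switch above lifts the runtime token count n_tok into the new L_template parameter, so short sequence lengths run a kernel whose per-token loop has a compile-time trip count (and can be fully unrolled), while L_template == 0 keeps the runtime-length path. A stripped-down, host-only C++ sketch of this runtime-to-compile-time dispatch pattern follows; it contains no CUDA, and every name in it is invented for the illustration.

#include <cstdio>

// L_template == 0 means "length only known at runtime"; otherwise the loop
// bound is a compile-time constant and the compiler can unroll it fully.
template <int L_template>
void scan(const float * x, float * y, int L_param) {
    const int L = L_template == 0 ? L_param : L_template;
    float acc = 0.0f;
    for (int i = 0; i < L; ++i) {  // fixed trip count when L_template > 0
        acc += x[i];
        y[i] = acc;
    }
}

// Runtime-to-compile-time dispatch, mirroring the switch over n_tok above.
void scan_dispatch(const float * x, float * y, int L) {
    switch (L) {
        case 1: scan<1>(x, y, L); break;
        case 2: scan<2>(x, y, L); break;
        case 3: scan<3>(x, y, L); break;
        case 4: scan<4>(x, y, L); break;
        default: scan<0>(x, y, L); break;  // generic path for longer sequences
    }
}

int main() {
    float x[6] = {1, 2, 3, 4, 5, 6}, y[6] = {};
    scan_dispatch(x, y, 3);              // unrolled specialization
    scan_dispatch(x, y, 6);              // runtime-length fallback
    std::printf("%g %g\n", y[2], y[5]);  // 6 21
    return 0;
}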
@@ -1119,7 +1119,8 @@ class TensorNameMap:
             "model.vision_tower.embeddings.patch_embeddings.projection", # Intern-S1
             "vpm.embeddings.patch_embedding",
             "model.vision_model.embeddings.patch_embedding", # SmolVLM
-            "vision_tower.patch_conv", # pixtral
+            "vision_tower.patch_conv", # pixtral-hf
+            "vision_encoder.patch_conv", # pixtral
             "vision_model.patch_embedding.linear", # llama 4
             "visual.patch_embed.proj", # qwen2vl
         ),
@@ -1138,7 +1139,8 @@ class TensorNameMap:
             "vpm.encoder.layers.{bid}.self_attn.q_proj",
             "model.vision_model.encoder.layers.{bid}.self_attn.q_proj", # SmolVLM
             "vision_model.model.layers.{bid}.self_attn.q_proj", # llama4
-            "vision_tower.transformer.layers.{bid}.attention.q_proj", # pixtral
+            "vision_tower.transformer.layers.{bid}.attention.q_proj", # pixtral-hf
+            "vision_encoder.transformer.layers.{bid}.attention.wq", # pixtral
             "visual.blocks.{bid}.attn.q", # qwen2vl, generated
         ),

@@ -1153,7 +1155,8 @@ class TensorNameMap:
             "vpm.encoder.layers.{bid}.self_attn.k_proj",
             "model.vision_model.encoder.layers.{bid}.self_attn.k_proj", # SmolVLM
             "vision_model.model.layers.{bid}.self_attn.k_proj", # llama4
-            "vision_tower.transformer.layers.{bid}.attention.k_proj", # pixtral
+            "vision_tower.transformer.layers.{bid}.attention.k_proj", # pixtral-hf
+            "vision_encoder.transformer.layers.{bid}.attention.wk", # pixtral
             "visual.blocks.{bid}.attn.k", # qwen2vl, generated
         ),

@@ -1168,7 +1171,8 @@ class TensorNameMap:
             "vpm.encoder.layers.{bid}.self_attn.v_proj",
             "model.vision_model.encoder.layers.{bid}.self_attn.v_proj", # SmolVLM
             "vision_model.model.layers.{bid}.self_attn.v_proj", # llama4
-            "vision_tower.transformer.layers.{bid}.attention.v_proj", # pixtral
+            "vision_tower.transformer.layers.{bid}.attention.v_proj", # pixtral-hf
+            "vision_encoder.transformer.layers.{bid}.attention.wv", # pixtral
             "visual.blocks.{bid}.attn.v", # qwen2vl, generated
         ),

@ -1178,7 +1182,8 @@ class TensorNameMap:
|
||||||
"model.vision_tower.encoder.layer.{bid}.layernorm_before", # Intern-S1
|
"model.vision_tower.encoder.layer.{bid}.layernorm_before", # Intern-S1
|
||||||
"vpm.encoder.layers.{bid}.layer_norm1",
|
"vpm.encoder.layers.{bid}.layer_norm1",
|
||||||
"model.vision_model.encoder.layers.{bid}.layer_norm1", # SmolVLM
|
"model.vision_model.encoder.layers.{bid}.layer_norm1", # SmolVLM
|
||||||
"vision_tower.transformer.layers.{bid}.attention_norm", # pixtral
|
"vision_tower.transformer.layers.{bid}.attention_norm", # pixtral-hf
|
||||||
|
"vision_encoder.transformer.layers.{bid}.attention_norm", # pixtral
|
||||||
"vision_model.model.layers.{bid}.input_layernorm", # llama4
|
"vision_model.model.layers.{bid}.input_layernorm", # llama4
|
||||||
"visual.blocks.{bid}.norm1", # qwen2vl
|
"visual.blocks.{bid}.norm1", # qwen2vl
|
||||||
),
|
),
|
||||||
|
@ -1190,7 +1195,8 @@ class TensorNameMap:
|
||||||
"vpm.encoder.layers.{bid}.self_attn.out_proj",
|
"vpm.encoder.layers.{bid}.self_attn.out_proj",
|
||||||
"model.vision_model.encoder.layers.{bid}.self_attn.out_proj", # SmolVLM
|
"model.vision_model.encoder.layers.{bid}.self_attn.out_proj", # SmolVLM
|
||||||
"vision_model.model.layers.{bid}.self_attn.o_proj", # llama4
|
"vision_model.model.layers.{bid}.self_attn.o_proj", # llama4
|
||||||
"vision_tower.transformer.layers.{bid}.attention.o_proj", # pixtral
|
"vision_tower.transformer.layers.{bid}.attention.o_proj", # pixtral-hf
|
||||||
|
"vision_encoder.transformer.layers.{bid}.attention.wo", # pixtral
|
||||||
"visual.blocks.{bid}.attn.proj", # qwen2vl
|
"visual.blocks.{bid}.attn.proj", # qwen2vl
|
||||||
),
|
),
|
||||||
|
|
||||||
|
@ -1201,7 +1207,8 @@ class TensorNameMap:
|
||||||
"vpm.encoder.layers.{bid}.layer_norm2",
|
"vpm.encoder.layers.{bid}.layer_norm2",
|
||||||
"model.vision_model.encoder.layers.{bid}.layer_norm2", # SmolVLM
|
"model.vision_model.encoder.layers.{bid}.layer_norm2", # SmolVLM
|
||||||
"vision_model.model.layers.{bid}.post_attention_layernorm", # llama4
|
"vision_model.model.layers.{bid}.post_attention_layernorm", # llama4
|
||||||
"vision_tower.transformer.layers.{bid}.ffn_norm", # pixtral
|
"vision_tower.transformer.layers.{bid}.ffn_norm", # pixtral-hf
|
||||||
|
"vision_encoder.transformer.layers.{bid}.ffn_norm", # pixtral
|
||||||
"visual.blocks.{bid}.norm2", # qwen2vl
|
"visual.blocks.{bid}.norm2", # qwen2vl
|
||||||
),
|
),
|
||||||
|
|
||||||
|
@ -1210,14 +1217,16 @@ class TensorNameMap:
|
||||||
"model.vision_tower.encoder.layer.{bid}.mlp.fc1", # Intern-S1
|
"model.vision_tower.encoder.layer.{bid}.mlp.fc1", # Intern-S1
|
||||||
"vpm.encoder.layers.{bid}.mlp.fc1",
|
"vpm.encoder.layers.{bid}.mlp.fc1",
|
||||||
"model.vision_model.encoder.layers.{bid}.mlp.fc1", # SmolVLM, gemma3
|
"model.vision_model.encoder.layers.{bid}.mlp.fc1", # SmolVLM, gemma3
|
||||||
"vision_tower.transformer.layers.{bid}.feed_forward.up_proj", # pixtral
|
"vision_tower.transformer.layers.{bid}.feed_forward.up_proj", # pixtral-hf
|
||||||
|
"vision_encoder.transformer.layers.{bid}.feed_forward.w3", # pixtral
|
||||||
"vision_model.model.layers.{bid}.mlp.fc1", # llama4
|
"vision_model.model.layers.{bid}.mlp.fc1", # llama4
|
||||||
"visual.blocks.{bid}.mlp.fc1", # qwen2vl
|
"visual.blocks.{bid}.mlp.fc1", # qwen2vl
|
||||||
"visual.blocks.{bid}.mlp.up_proj", # qwen2.5vl
|
"visual.blocks.{bid}.mlp.up_proj", # qwen2.5vl
|
||||||
),
|
),
|
||||||
|
|
||||||
MODEL_TENSOR.V_ENC_FFN_GATE: (
|
MODEL_TENSOR.V_ENC_FFN_GATE: (
|
||||||
"vision_tower.transformer.layers.{bid}.feed_forward.gate_proj", # pixtral
|
"vision_tower.transformer.layers.{bid}.feed_forward.gate_proj", # pixtral-hf
|
||||||
|
"vision_encoder.transformer.layers.{bid}.feed_forward.w1", # pixtral
|
||||||
"visual.blocks.{bid}.mlp.gate_proj", # qwen2.5vl
|
"visual.blocks.{bid}.mlp.gate_proj", # qwen2.5vl
|
||||||
),
|
),
|
||||||
|
|
||||||
|
@ -1226,7 +1235,8 @@ class TensorNameMap:
|
||||||
"model.vision_tower.encoder.layer.{bid}.mlp.fc2", # Intern-S1
|
"model.vision_tower.encoder.layer.{bid}.mlp.fc2", # Intern-S1
|
||||||
"vpm.encoder.layers.{bid}.mlp.fc2",
|
"vpm.encoder.layers.{bid}.mlp.fc2",
|
||||||
"model.vision_model.encoder.layers.{bid}.mlp.fc2", # SmolVLM, gemma3
|
"model.vision_model.encoder.layers.{bid}.mlp.fc2", # SmolVLM, gemma3
|
||||||
"vision_tower.transformer.layers.{bid}.feed_forward.down_proj", # pixtral
|
"vision_tower.transformer.layers.{bid}.feed_forward.down_proj", # pixtral-hf
|
||||||
|
"vision_encoder.transformer.layers.{bid}.feed_forward.w2", # pixtral
|
||||||
"vision_model.model.layers.{bid}.mlp.fc2", # llama4
|
"vision_model.model.layers.{bid}.mlp.fc2", # llama4
|
||||||
"visual.blocks.{bid}.mlp.fc2", # qwen2vl
|
"visual.blocks.{bid}.mlp.fc2", # qwen2vl
|
||||||
"visual.blocks.{bid}.mlp.down_proj", # qwen2.5vl
|
"visual.blocks.{bid}.mlp.down_proj", # qwen2.5vl
|
||||||
|
@ -1244,7 +1254,8 @@ class TensorNameMap:
|
||||||
|
|
||||||
MODEL_TENSOR.V_PRE_NORM: (
|
MODEL_TENSOR.V_PRE_NORM: (
|
||||||
"vision_tower.vision_model.pre_layrnorm",
|
"vision_tower.vision_model.pre_layrnorm",
|
||||||
"vision_tower.ln_pre", # pixtral
|
"vision_tower.ln_pre", # pixtral-hf
|
||||||
|
"vision_encoder.ln_pre", # pixtral
|
||||||
"vision_model.layernorm_pre", # llama4
|
"vision_model.layernorm_pre", # llama4
|
||||||
),
|
),
|
||||||
|
|
||||||
|
@ -1261,6 +1272,7 @@ class TensorNameMap:
|
||||||
|
|
||||||
MODEL_TENSOR.V_MM_INP_NORM: (
|
MODEL_TENSOR.V_MM_INP_NORM: (
|
||||||
"multi_modal_projector.norm",
|
"multi_modal_projector.norm",
|
||||||
|
"pre_mm_projector_norm",
|
||||||
),
|
),
|
||||||
|
|
||||||
MODEL_TENSOR.V_MM_SOFT_EMB_NORM: (
|
MODEL_TENSOR.V_MM_SOFT_EMB_NORM: (
|
||||||
|
@ -1316,7 +1328,8 @@ class TensorNameMap:
|
||||||
),
|
),
|
||||||
|
|
||||||
MODEL_TENSOR.V_MM_PATCH_MERGER: (
|
MODEL_TENSOR.V_MM_PATCH_MERGER: (
|
||||||
"multi_modal_projector.patch_merger.merging_layer", # mistral small 3.1
|
"multi_modal_projector.patch_merger.merging_layer", # mistral small 3.1 - hf
|
||||||
|
"patch_merger.merging_layer", # mistral
|
||||||
),
|
),
|
||||||
|
|
||||||
# audio (mtmd)
|
# audio (mtmd)
|
||||||
|
|
|
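The recurring pattern in the hunks above is one new alias per tensor: the existing `vision_tower.*` names are relabelled as pixtral-hf and the Mistral-format `vision_encoder.*` names are added next to them. A simplified stand-in showing how such `{bid}`-templated aliases resolve a checkpoint tensor name (a sketch, not the actual TensorNameMap implementation):

import re

# Two of the aliases added above, keyed by an abstract tensor id.
TENSOR_ALIASES = {
    "V_ENC_ATTN_Q": (
        "vision_tower.transformer.layers.{bid}.attention.q_proj",  # pixtral-hf
        "vision_encoder.transformer.layers.{bid}.attention.wq",    # pixtral
    ),
}

def map_tensor_name(name):
    # Turn each template into a regex with a capture group where {bid} sits.
    for tensor_id, templates in TENSOR_ALIASES.items():
        for template in templates:
            pattern = "^" + re.escape(template).replace(r"\{bid\}", r"(\d+)") + "$"
            m = re.match(pattern, name)
            if m:
                return tensor_id, int(m.group(1))
    return None

print(map_tensor_name("vision_encoder.transformer.layers.3.attention.wq"))  # ('V_ENC_ATTN_Q', 3)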
@@ -145,7 +145,11 @@ class SafetensorRemote:
                 tensors[key] = val
             return tensors

-        raise ValueError(f"Model {model_id} does not have any safetensor files")
+        raise ValueError(
+            f"No safetensor file has been found for model {model_id}."
+            "If the repo has safetensor files, make sure the model is public or you have a "
+            "valid Hugging Face token set in the environment variable HF_TOKEN."
+        )

     @classmethod
     def get_list_tensors(cls, url: str) -> dict[str, RemoteTensor]:
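One detail worth knowing when reading the new raise: the three string literals inside the parentheses are joined by Python's implicit literal concatenation, so any separating spaces have to be written at the literal boundaries. A tiny standalone check (the model id is made up):

model_id = "some-org/private-model"  # hypothetical id, for illustration only

message = (
    f"No safetensor file has been found for model {model_id}."
    "If the repo has safetensor files, make sure the model is public or you have a "
    "valid Hugging Face token set in the environment variable HF_TOKEN."
)
print(message)  # note: no space follows the first '.' unless one is added explicitly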
@@ -223,12 +223,7 @@ void llama_kv_cache_unified::clear(bool data) {
 }

 bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
-    GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
-
-    auto & cells = v_cells[seq_to_stream[seq_id]];
-    auto & head  = v_heads[seq_to_stream[seq_id]];
-
-    uint32_t new_head = cells.size();
+    GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()));

     if (p0 < 0) {
         p0 = 0;
@@ -239,6 +234,11 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
     }

     if (seq_id >= 0) {
+        auto & cells = v_cells[seq_to_stream[seq_id]];
+        auto & head  = v_heads[seq_to_stream[seq_id]];
+
+        uint32_t new_head = cells.size();
+
         for (uint32_t i = 0; i < cells.size(); ++i) {
             if (!cells.pos_in(i, p0, p1)) {
                 continue;
@@ -250,8 +250,19 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
                 }
             }
         }
+
+        // If we freed up a slot, set head to it so searching can start there.
+        if (new_head != cells.size() && new_head < head) {
+            head = new_head;
+        }
     } else {
         // match any sequence
+        for (uint32_t s = 0; s < n_stream; ++s) {
+            auto & cells = v_cells[s];
+            auto & head  = v_heads[s];
+
+            uint32_t new_head = cells.size();
+
         for (uint32_t i = 0; i < cells.size(); ++i) {
             if (!cells.pos_in(i, p0, p1)) {
                 continue;
@@ -263,12 +274,13 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
                 new_head = i;
             }
         }
-    }

         // If we freed up a slot, set head to it so searching can start there.
         if (new_head != cells.size() && new_head < head) {
             head = new_head;
         }
+        }
+    }

     return true;
 }
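The combined effect of the seq_rm hunks: the assertion now admits seq_id == -1, the per-sequence branch keeps its own cells/head/new_head bookkeeping, and the -1 branch repeats that bookkeeping for every stream. A condensed Python model of that flow (a sketch of the logic only, not the llama.cpp API; cells are plain dicts):

def seq_rm(streams, seq_to_stream, seq_id, p0, p1):
    """Remove positions in [p0, p1) for one sequence, or for all streams when seq_id == -1."""
    assert seq_id == -1 or 0 <= seq_id < len(seq_to_stream)
    p0 = max(p0, 0)
    if p1 < 0:
        p1 = float("inf")

    def purge(stream, seq_filter):
        new_head = None
        for i in sorted(stream["cells"]):
            cell = stream["cells"][i]
            if not (p0 <= cell["pos"] < p1):
                continue
            if seq_filter is None:
                cell["seqs"].clear()           # match any sequence
            elif seq_filter in cell["seqs"]:
                cell["seqs"].discard(seq_filter)
            else:
                continue
            if not cell["seqs"]:
                del stream["cells"][i]
                if new_head is None:
                    new_head = i               # lowest freed slot
        # If we freed up a slot, set head to it so searching can start there.
        if new_head is not None and new_head < stream["head"]:
            stream["head"] = new_head

    if seq_id >= 0:
        purge(streams[seq_to_stream[seq_id]], seq_id)
    else:
        for stream in streams:                 # seq_id == -1: visit every stream
            purge(stream, None)
    return True

# usage: one stream whose cell 0 holds sequence 0 at position 5
streams = [{"cells": {0: {"pos": 5, "seqs": {0}}}, "head": 1}]
seq_rm(streams, seq_to_stream=[0], seq_id=-1, p0=0, p1=-1)
print(streams)  # cell removed, head moved back to 0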
@@ -738,13 +750,16 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d
 }

 llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch, bool cont) const {

     if (debug > 0) {
-        const auto & cells = v_cells[seq_to_stream[1]];
+        for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
+            const auto seq_id    = ubatch.seq_id_unq[s];
+            const auto stream_id = seq_to_stream[seq_id];
+            const auto & cells   = v_cells[stream_id];
+            const uint32_t head_cur = v_heads[stream_id];

-        const uint32_t head_cur = v_heads[1];
-        LLAMA_LOG_DEBUG("%s: n = %5d, used = %5d, head = %5d, size = %5d, n_swa = %5d\n",
-                __func__, cells.used_max_p1(), cells.get_used(), head_cur, get_size(), n_swa);
+            LLAMA_LOG_DEBUG("%s: stream[%d], n = %5d, used = %5d, head = %5d, size = %5d, n_swa = %5d\n",
+                    __func__, stream_id, cells.used_max_p1(), cells.get_used(), head_cur, get_size(), n_swa);

         if ((debug == 2 && n_swa > 0) || debug > 2) {
             std::string ss;
@@ -797,7 +812,8 @@ llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_
                 continue;
             }

-            LLAMA_LOG_DEBUG("%s: min[%d] = %5d, max[%d] = %5d\n", __func__, s, cells.seq_pos_min(s), s, cells.seq_pos_max(s));
+            LLAMA_LOG_DEBUG("%s: stream[%d] min[%d] = %5d, max[%d] = %5d\n", __func__, stream_id, s, cells.seq_pos_min(s), s, cells.seq_pos_max(s));
+        }
         }
     }
@@ -44,6 +44,7 @@
 #define KEY_WIN_ATTN_PATTERN    "clip.vision.n_wa_pattern"
 #define KEY_ATTN_WINDOW_SIZE    "clip.vision.window_size"
 #define KEY_MINICPMV_VERSION    "clip.minicpmv_version"
+#define KEY_MINICPMV_QUERY_NUM  "clip.minicpmv_query_num"

 // audio-specific
 #define KEY_A_NUM_MEL_BINS      "clip.audio.num_mel_bins"

@@ -219,6 +219,7 @@ struct clip_hparams {
     // legacy
     bool has_llava_projector = false;
     int minicpmv_version = 0;
+    int32_t minicpmv_query_num = 0; // MiniCPM-V query number
 };

 struct clip_layer {
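The new KEY_MINICPMV_QUERY_NUM / minicpmv_query_num pair lets converted mmproj files state the resampler query count explicitly instead of relying on the version number alone. A quick presence check with the gguf-py reader (a sketch; the file name is a placeholder):

from gguf import GGUFReader  # gguf-py, part of the llama.cpp tree

reader = GGUFReader("mmproj-model-f16.gguf")  # placeholder path

# reader.fields maps metadata key names to field records; presence alone tells us
# whether the converter wrote the new key or the loader will use its version fallback.
if "clip.minicpmv_query_num" in reader.fields:
    print("explicit query count stored in the mmproj metadata")
else:
    print("legacy mmproj: clip.cpp falls back to clip.minicpmv_version defaults")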
@@ -891,21 +892,8 @@ struct clip_graph {
         int n_embd = clip_n_mmproj_embd(ctx);
         const int d_head = 128;
         int n_head = n_embd/d_head;
-        int num_query = 96;
-        if (ctx->model.hparams.minicpmv_version == 2) {
-            // MiniCPM-V 2.5
-            num_query = 96;
-        } else if (ctx->model.hparams.minicpmv_version == 3) {
-            // MiniCPM-V 2.6
-            num_query = 64;
-        } else if (ctx->model.hparams.minicpmv_version == 4) {
-            // MiniCPM-o 2.6
-            num_query = 64;
-        } else if (ctx->model.hparams.minicpmv_version == 5) {
-            // MiniCPM-V 4.0
-            num_query = 64;
-        }
-
+        // Use actual config value if available, otherwise fall back to hardcoded values
+        int num_query = ctx->model.hparams.minicpmv_query_num;
         ggml_tensor * Q = ggml_add(ctx0,
             ggml_mul_mat(ctx0, model.mm_model_attn_q_w, q),
             model.mm_model_attn_q_b);

@@ -2178,7 +2166,19 @@ struct clip_model_loader {
             get_u32(KEY_PATCH_SIZE, hparams.patch_size);
             get_u32(KEY_IMAGE_CROP_RESOLUTION, hparams.image_crop_resolution, false);
             get_i32(KEY_MINICPMV_VERSION, hparams.minicpmv_version, false); // legacy
+            get_u32(KEY_MINICPMV_QUERY_NUM, hparams.minicpmv_query_num, false);
+            if (hparams.minicpmv_query_num == 0) {
+                // Fallback to hardcoded values for legacy models
+                if (hparams.minicpmv_version == 3) {
+                    hparams.minicpmv_query_num = 64;
+                } else if (hparams.minicpmv_version == 4) {
+                    hparams.minicpmv_query_num = 64;
+                } else if (hparams.minicpmv_version == 5) {
+                    hparams.minicpmv_query_num = 64;
+                } else {
+                    hparams.minicpmv_query_num = 96;
+                }
+            }
         } else if (is_audio) {
             get_u32(KEY_A_NUM_MEL_BINS, hparams.n_mel_bins);

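Read together, the loader hunk means the query count comes from GGUF metadata when available and only otherwise from the legacy version table. The same decision as a standalone helper (illustrative; the constants are copied from the diff):

def minicpmv_query_num(query_num_from_gguf: int, minicpmv_version: int) -> int:
    # Prefer the explicit metadata value written by newer converters.
    if query_num_from_gguf > 0:
        return query_num_from_gguf
    # Legacy fallback mirroring the table above: versions 3 (MiniCPM-V 2.6),
    # 4 (MiniCPM-o 2.6) and 5 (MiniCPM-V 4.0) use 64 queries, everything else 96.
    if minicpmv_version in (3, 4, 5):
        return 64
    return 96

assert minicpmv_query_num(0, 2) == 96
assert minicpmv_query_num(0, 5) == 64
assert minicpmv_query_num(64, 2) == 64  # explicit metadata wins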
@@ -3732,14 +3732,16 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
             } break;
         case PROJECTOR_TYPE_MINICPMV:
             {
+                // Use actual config value if available, otherwise fall back to hardcoded values
+                if (params.minicpmv_query_num > 0) {
+                    n_patches_sq = params.minicpmv_query_num;
+                } else {
+                    // Fallback to hardcoded values for legacy models
                 if (params.minicpmv_version == 2) {
-                    // MiniCPM-V 2.5
                     n_patches_sq = 96;
                 } else if (params.minicpmv_version == 3) {
-                    // MiniCPM-V 2.6
                     n_patches_sq = 64;
                 } else if (params.minicpmv_version == 4) {
-                    // MiniCPM-o 2.6
                     n_patches_sq = 64;
                 } else if (params.minicpmv_version == 5) {
                     // MiniCPM-V 4.0
@@ -3747,6 +3749,7 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
                 } else {
                     GGML_ABORT("Unknown minicpmv version");
                 }
+                }
             } break;
         case PROJECTOR_TYPE_QWEN2VL:
         case PROJECTOR_TYPE_QWEN25VL:
@@ -4458,7 +4461,6 @@ bool clip_model_quantize(const char * fname_inp, const char * fname_out, const i
 }

 int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
-    const auto & hparams = ctx->model.hparams;
     switch (ctx->model.proj_type) {
         case PROJECTOR_TYPE_LDP:
             return ctx->model.mm_model_block_1_block_2_1_b->ne[0];
@@ -4470,20 +4472,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
         case PROJECTOR_TYPE_MLP_NORM:
             return ctx->model.mm_3_b->ne[0];
         case PROJECTOR_TYPE_MINICPMV:
-            if (hparams.minicpmv_version == 2) {
-                // MiniCPM-V 2.5
-                return 4096;
-            } else if (hparams.minicpmv_version == 3) {
-                // MiniCPM-V 2.6
-                return 3584;
-            } else if (hparams.minicpmv_version == 4) {
-                // MiniCPM-o 2.6
-                return 3584;
-            } else if (hparams.minicpmv_version == 5) {
-                // MiniCPM-V 4.0
-                return 2560;
-            }
-            GGML_ABORT("Unknown minicpmv version");
+            return ctx->model.mm_model_proj->ne[0];
         case PROJECTOR_TYPE_GLM_EDGE:
             return ctx->model.mm_model_mlp_3_w->ne[1];
         case PROJECTOR_TYPE_QWEN2VL:
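For clip_n_mmproj_embd, the hardcoded 4096/3584/2560 table is replaced by reading the width off mm_model_proj itself. A numpy sketch of the idea, with a made-up tensor whose last axis is assumed to stand in for the LLM embedding width:

import numpy as np

# Hypothetical resampler projection weight; only the shape matters here.
# 3584 would correspond to MiniCPM-V 2.6 in the removed table above.
mm_model_proj = np.zeros((4096, 3584), dtype=np.float32)

n_mmproj_embd = mm_model_proj.shape[-1]  # replaces the per-version constants
print(n_mmproj_embd)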
@@ -517,6 +517,16 @@ if args.use_f32:
 # output in the same directory as the model if output_dir is None
 dir_model = args.model_dir

+# Read config.json to get actual model configuration
+config_path = os.path.join(dir_model, "config.json")
+model_config = {}
+if os.path.isfile(config_path):
+    with open(config_path, "r", encoding="utf-8") as f:
+        model_config = json.load(f)
+    print(f"Loaded config from {config_path}")
+else:
+    print(f"Warning: config.json not found at {config_path}")
+
 # If minicpmv_projector is not specified but the default path exists, use the default path
 if args.minicpmv_projector is None:
     default_projector_path = os.path.join(dir_model, "minicpmv.projector")
@@ -555,25 +565,50 @@ if args.use_f32:
 # processor = CLIPProcessor.from_pretrained(dir_model)

 minicpmv_version = args.minicpmv_version
-emb_dim = 4096
-block_count = 26
-if minicpmv_version == 1: # MiniCPM-V 2.0
+# Use actual config values instead of hardcoded ones
+if model_config:
+    # For the projector/resampler, use the main model's hidden_size
+    emb_dim = model_config.get("hidden_size", 1536)
+
+    # For the vision model, use vision_config values
+    vision_config_dict = model_config.get("vision_config", {})
+    default_vision_config = {
+        "hidden_size": vision_config_dict.get("hidden_size", 1152),
+        "image_size": vision_config_dict.get("image_size", 980),
+        "intermediate_size": vision_config_dict.get("intermediate_size", 4304),
+        "model_type": vision_config_dict.get("model_type", "siglip"),
+        "num_attention_heads": vision_config_dict.get("num_attention_heads", 16),
+        "num_hidden_layers": vision_config_dict.get("num_hidden_layers", 27),
+        "patch_size": vision_config_dict.get("patch_size", 14),
+    }
+
+    # Use vision model's num_hidden_layers for block_count
+    block_count = vision_config_dict.get("num_hidden_layers", 27)
+
+    print(f"Using config values: emb_dim={emb_dim}, block_count={block_count}")
+    print(f"Vision config: {default_vision_config}")
+else:
+    # Fallback to original hardcoded logic if config.json not found
+    emb_dim = 4096
+    block_count = 26
+    if minicpmv_version == 1:
         emb_dim = 2304
         block_count = 26
-elif minicpmv_version == 2: # MiniCPM-V 2.5
+    elif minicpmv_version == 2:
         emb_dim = 4096
         block_count = 27
-elif minicpmv_version == 3: # MiniCPM-V 2.6
+    elif minicpmv_version == 3:
         emb_dim = 3584
         block_count = 27
-elif minicpmv_version == 4: # MiniCPM-o 2.6
+    elif minicpmv_version == 4:
         emb_dim = 3584
         block_count = 27
-elif minicpmv_version == 5: # MiniCPM-V 4.0
+    elif minicpmv_version == 5:
         emb_dim = 2560
         block_count = 27

 default_vision_config = {
     "hidden_size": 1152,
     "image_size": 980,
     "intermediate_size": 4304,
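The converter now prefers config.json for the resampler width and vision depth, keeping the per-version table only as a fallback. The same selection, condensed into one helper (a sketch; defaults copied from the diff):

def derive_dims(model_config: dict, minicpmv_version: int):
    """Return (emb_dim, block_count) the way the converter above chooses them."""
    if model_config:
        vision = model_config.get("vision_config", {})
        return model_config.get("hidden_size", 1536), vision.get("num_hidden_layers", 27)
    # fallback table for legacy checkpoints without a readable config.json
    table = {1: (2304, 26), 2: (4096, 27), 3: (3584, 27), 4: (3584, 27), 5: (2560, 27)}
    return table.get(minicpmv_version, (4096, 26))

print(derive_dims({}, 3))                     # (3584, 27) via the fallback table
print(derive_dims({"hidden_size": 2560}, 5))  # (2560, 27) straight from config.json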
@@ -585,7 +620,7 @@ default_vision_config = {

 vision_config = Idefics2VisionConfig(**default_vision_config)
 model = Idefics2VisionTransformer(vision_config)
-if minicpmv_version == 3:
+if minicpmv_version == 3 or (model_config and model_config.get("vision_config", {}).get("model_type") == "siglip"):
     vision_config = SiglipVisionConfig(**default_vision_config)
     model = SiglipVisionTransformer(vision_config)
 elif minicpmv_version == 4:
@@ -644,16 +679,27 @@ else:
     fout.add_description("two-tower CLIP model")

 if has_vision_encoder:
-    # vision_model hparams
-    fout.add_uint32("clip.vision.image_size", 448)
-    fout.add_uint32("clip.vision.patch_size", 14)
-    fout.add_uint32(add_key_str(KEY_EMBEDDING_LENGTH, VISION), 1152)
-    fout.add_uint32(add_key_str(KEY_FEED_FORWARD_LENGTH, VISION), 4304)
+    # vision_model hparams - use actual config values
+    vision_image_size = model_config.get("image_size", 448) if model_config else 448
+    vision_patch_size = default_vision_config.get("patch_size", 14)
+    vision_hidden_size = default_vision_config.get("hidden_size", 1152)
+    vision_intermediate_size = default_vision_config.get("intermediate_size", 4304)
+    vision_attention_heads = default_vision_config.get("num_attention_heads", 16)
+
+    fout.add_uint32("clip.vision.image_size", vision_image_size)
+    fout.add_uint32("clip.vision.patch_size", vision_patch_size)
+    fout.add_uint32(add_key_str(KEY_EMBEDDING_LENGTH, VISION), vision_hidden_size)
+    fout.add_uint32(add_key_str(KEY_FEED_FORWARD_LENGTH, VISION), vision_intermediate_size)
     fout.add_uint32("clip.vision.projection_dim", 0)
-    fout.add_uint32(add_key_str(KEY_ATTENTION_HEAD_COUNT, VISION), 16)
+    fout.add_uint32(add_key_str(KEY_ATTENTION_HEAD_COUNT, VISION), vision_attention_heads)
     fout.add_float32(add_key_str(KEY_ATTENTION_LAYERNORM_EPS, VISION), 1e-6)
     fout.add_uint32(add_key_str(KEY_BLOCK_COUNT, VISION), block_count)

+    # Add MiniCPM-V specific parameters
+    query_num = model_config.get("query_num", 0) if model_config else 0
+    resampler_emb_dim = model_config.get("hidden_size", 0) if model_config else 0
+    fout.add_uint32("clip.minicpmv_query_num", query_num)
+
     if processor is not None:
         image_mean = processor.image_processor.image_mean if args.image_mean is None or args.image_mean == default_image_mean else args.image_mean
         image_std = processor.image_processor.image_std if args.image_std is None or args.image_std == default_image_std else args.image_std
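The writer section above derives the vision hparams and the new clip.minicpmv_query_num key from the loaded config instead of constants. A sketch of the resulting key/value set for a placeholder config (the values are placeholders, and the fully expanded key names are my reading of add_key_str(..., VISION), not spelled out in the diff):

# Placeholder config, standing in for the model_config/default_vision_config read earlier.
model_config = {"image_size": 448, "query_num": 64, "hidden_size": 3584}
default_vision_config = {"patch_size": 14, "hidden_size": 1152,
                         "intermediate_size": 4304, "num_attention_heads": 16}

kv = {
    "clip.vision.image_size": model_config.get("image_size", 448),
    "clip.vision.patch_size": default_vision_config.get("patch_size", 14),
    "clip.vision.embedding_length": default_vision_config.get("hidden_size", 1152),
    "clip.vision.feed_forward_length": default_vision_config.get("intermediate_size", 4304),
    "clip.vision.attention.head_count": default_vision_config.get("num_attention_heads", 16),
    "clip.minicpmv_query_num": model_config.get("query_num", 0),
}
for key, value in kv.items():
    print(f"{key} = {value}")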
@ -16,6 +16,8 @@ mm_tensors = [k for k, v in checkpoint.items() if k.startswith("resampler")]
|
||||||
|
|
||||||
# store these tensors in a new dictionary and torch.save them
|
# store these tensors in a new dictionary and torch.save them
|
||||||
projector = {name: checkpoint[name].float() for name in mm_tensors}
|
projector = {name: checkpoint[name].float() for name in mm_tensors}
|
||||||
|
if 'resampler.proj' in projector.keys() and hasattr(model.llm.config,'scale_emb') is True:
|
||||||
|
projector['resampler.proj'] = projector['resampler.proj'] / model.llm.config.scale_emb
|
||||||
torch.save(projector, f"{args.model}/minicpmv.projector")
|
torch.save(projector, f"{args.model}/minicpmv.projector")
|
||||||
|
|
||||||
clip_tensors = [k for k, v in checkpoint.items() if k.startswith("vpm")]
|
clip_tensors = [k for k, v in checkpoint.items() if k.startswith("vpm")]
|
||||||
|
|
Binary file not shown.

@@ -383,8 +383,12 @@ struct server_task {
         } else {
             params.oaicompat_chat_syntax.format = defaults.oaicompat_chat_syntax.format;
         }
-        params.oaicompat_chat_syntax.reasoning_format = params_base.reasoning_format;
-        params.oaicompat_chat_syntax.reasoning_in_content = params.stream && (params_base.reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY);
+        common_reasoning_format reasoning_format = params_base.reasoning_format;
+        if (data.contains("reasoning_format")) {
+            reasoning_format = common_reasoning_format_from_name(data.at("reasoning_format").get<std::string>());
+        }
+        params.oaicompat_chat_syntax.reasoning_format = reasoning_format;
+        params.oaicompat_chat_syntax.reasoning_in_content = params.stream && (reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY);
         params.oaicompat_chat_syntax.thinking_forced_open = json_value(data, "thinking_forced_open", false);
         params.oaicompat_chat_syntax.parse_tool_calls = json_value(data, "parse_tool_calls", false);
     }
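Because the server now honours a reasoning_format field in the request body (falling back to the --reasoning-format launch setting otherwise), clients can override it per call, which is exactly what the web UI change below does with 'none'. A minimal client sketch (assumes a llama-server listening locally on the default port):

import json
import urllib.request

payload = {
    "messages": [{"role": "user", "content": "Why is the sky blue?"}],
    "reasoning_format": "none",   # per-request override: leave reasoning unparsed in message.content
    "stream": False,
}
req = urllib.request.Request(
    "http://localhost:8080/v1/chat/completions",   # assumed local llama-server endpoint
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    print(json.loads(resp.read())["choices"][0]["message"]["content"])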
@@ -209,6 +209,7 @@ export const AppContextProvider = ({
           messages,
           stream: true,
           cache_prompt: true,
+          reasoning_format: 'none',
           samplers: config.samplers,
           temperature: config.temperature,
           dynatemp_range: config.dynatemp_range,