Mirror of https://github.com/LostRuins/koboldcpp.git (synced 2025-09-13 18:39:48 +00:00)
llama : add gpt-oss (#15091)
* oai moe
* compat with new checkpoint
* add attn sink impl
* add rope scaling yarn
* logits match with latest transformers code
* wip chat template
* rm trailing space
* use ggml_scale_bias
* rm redundant is_swa_all
* convert interleaved gate_up
* graph : fix activation function to match reference (#7)
* vocab : handle o200k_harmony special tokens
* ggml : add attention sinks support (#1)
  * llama : add attn sinks
  * ggml : add attn sinks
  * cuda : add attn sinks
  * vulkan : add support for sinks in softmax; remove unnecessary return
* ggml : add fused swiglu_oai op (#11)
  * ggml : add fused swiglu_oai op
  * Update ggml/src/ggml-cpu/ops.cpp
  * update CUDA impl
  * cont : metal impl
  * add vulkan impl
  * test-backend-ops : more test cases, clean up
  * llama : remove unfused impl
  * remove extra lines
* repack mxfp4 upon conversion
* clean up a bit
* enable thinking
* add quick hack to render only some special tokens
* fix bf16 conversion
* remove vocab hack
* webui ok
* support chat parsing for gpt-oss
* fix webui
* direct mapping mxfp4, FINALLY
* force using mxfp4
* properly use lazy tensor
* ggml : add mxfp4
  * ggml : use e8m0 conversion instead of powf
  * change kvalues_mxfp4 table to match e2m1 (#6)
  * metal : remove quantization for now (not used)
  * cuda : fix disabled CUDA graphs due to ffn moe bias
  * vulkan : add support for mxfp4
  * cont : add cm2 dequant
* ggml : add ggml_add_id (#13)
  * ggml : add ggml_add_id
  * add cuda impl
  * llama : add weight support check for add_id
  * perf opt
  * add vulkan impl
  * rename cuda files
  * add metal impl
  * allow in-place ggml_add_id
* llama : keep biases on CPU with --cpu-moe
* llama : fix compile error (ggml-ci)
* cuda : add fallback for __nv_cvt_e8m0_to_bf16raw (ggml-ci)
* cleanup (ggml-ci)
* sycl : fix supports_op for MXFP4 (ggml-ci)
* fix "Unknown reasoning format"
* ggml-cpu : fix AVX build (ggml-ci)
* fix hip build (ggml-ci)
* cuda : add mxfp4 dequantization support for cuBLAS (ggml-ci)
* ggml-cpu : fix mxfp4 fallback definitions for some architectures (ggml-ci)
* cuda : fix version required for __nv_cvt_e8m0_to_bf16raw

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Co-authored-by: slaren <slarengh@gmail.com>
Co-authored-by: Diego Devesa <slarengh@gmail.com>
Co-authored-by: Xuan Son Nguyen <son@huggingface.co>
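Several items above ("add attn sink impl", "ggml : add attention sinks support", "vulkan : add support for sinks in softmax") refer to gpt-oss attention sinks: a learned per-head logit that takes part in the attention softmax but has no value vector attached. A minimal PyTorch sketch of the idea, with illustrative shapes and names rather than the actual llama.cpp kernel:

import torch

def softmax_with_sink(scores: torch.Tensor, sinks: torch.Tensor) -> torch.Tensor:
    # scores: [n_head, n_q, n_kv] raw attention logits; sinks: [n_head] learned scalars.
    # The sink joins the softmax denominator and is then dropped, so each row's
    # probabilities over the keys can sum to less than 1.
    sink_col = sinks.view(-1, 1, 1).expand(scores.shape[0], scores.shape[1], 1)
    probs = torch.softmax(torch.cat((scores, sink_col), dim=-1), dim=-1)
    return probs[..., :-1]  # keep only the key columns

# tiny usage check with made-up sizes
attn = softmax_with_sink(torch.randn(2, 3, 5), torch.randn(2))
assert attn.shape == (2, 3, 5)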
Parent: f324a3b715
Commit: fd1234cb46
83 changed files with 2942 additions and 227 deletions
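Before the conversion-script diff: the "ggml : add fused swiglu_oai op" and "convert interleaved gate_up" items concern the gpt-oss MoE feed-forward, which uses a clamped SwiGLU variant whose gate and up channels are interleaved in the checkpoint. A rough sketch of that activation; the alpha and limit constants are the commonly cited reference defaults and should be treated as assumptions, not as the exact ggml kernel:

import torch

def swiglu_oai(gate: torch.Tensor, up: torch.Tensor,
               alpha: float = 1.702, limit: float = 7.0) -> torch.Tensor:
    # Clamped SwiGLU as described for gpt-oss: a sigmoid-gated gate channel times a
    # shifted linear channel. The alpha/limit defaults here are assumptions.
    gate = gate.clamp(max=limit)
    up = up.clamp(min=-limit, max=limit)
    return gate * torch.sigmoid(alpha * gate) * (up + 1.0)

# the checkpoint interleaves gate and up along one dimension; the diff below splits
# the biases the same way (even indices -> gate, odd indices -> up)
fused = torch.randn(4, 8)
out = swiglu_oai(fused[..., ::2], fused[..., 1::2])
assert out.shape == (4, 4)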
@@ -7950,6 +7950,119 @@ class SmolLM3Model(LlamaModel):
        self.gguf_writer.add_chat_template(chat_template)


@ModelBase.register("GptOssForCausalLM")
class GptOssModel(TextModel):
    model_arch = gguf.MODEL_ARCH.GPT_OSS

    def transform_nibble_layout(self, tensor):
        assert tensor.dtype == torch.uint8
        assert tensor.shape[-1] == 16
        # swap nibbles
        t_lo = tensor & 0x0F
        t_hi = tensor & 0xF0
        t_swapped = (t_lo << 4) | (t_hi >> 4)
        tensor = t_swapped
        # transform aaaa...bbbb... to abababab...
        blk_a, blk_b = tensor.chunk(2, dim=-1)
        # get a_
        blk_a0 = (blk_a & 0xF0).view(-1, 1)
        blk_a1 = (blk_a << 4).view(-1, 1)
        blk_a = torch.stack((blk_a0, blk_a1), dim=2).view(tensor.shape)
        # get _b
        blk_b0 = (blk_b >> 4).view(-1, 1)
        blk_b1 = (blk_b & 0x0F).view(-1, 1)
        blk_b = torch.stack((blk_b0, blk_b1), dim=2).view(tensor.shape)
        # swap once more
        out = blk_a | blk_b
        out_h = out & 0xF0
        out_l = out & 0x0F
        out = (out_h >> 4) | (out_l << 4)
        return out

    def repack_mxfp4(self, new_name: str, blocks: Tensor, scales: Tensor):
        assert blocks.dtype == torch.uint8
        assert scales.dtype == torch.uint8
        scales = scales.unsqueeze(-1)
        assert len(blocks.shape) == 4
        assert len(scales.shape) == 4
        blocks = self.transform_nibble_layout(blocks)
        new_data = torch.concat((scales, blocks), dim=-1)
        new_shape = [new_data.shape[0], new_data.shape[1], new_data.shape[2] * 32]
        logger.info(f"Repacked {new_name} with shape {new_shape} and quantization MXFP4")
        # flatten last dim
        new_data = new_data.view(new_data.shape[0], new_data.shape[1], new_data.shape[2] * new_data.shape[3])
        new_data = new_data.numpy()
        self.gguf_writer.add_tensor(new_name, new_data, raw_dtype=gguf.GGMLQuantizationType.MXFP4)

    def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
        blocks0: Tensor = torch.zeros(1)
        blocks1: Tensor = torch.zeros(1)
        found_mxfp4_tensors = False
        # we assume that tensors are loaded in the correct order
        for name, data_torch in self.get_tensors():
            if "mlp.experts.down_proj_blocks" in name:
                blocks0 = data_torch
            elif "mlp.experts.down_proj_scales" in name:
                new_name = self.map_tensor_name(name.replace("_scales", ".weight"))
                self.repack_mxfp4(new_name, blocks0, data_torch)
                found_mxfp4_tensors = True
            elif "mlp.experts.gate_up_proj_blocks" in name:
                blocks0, blocks1 = data_torch[:, ::2, :, :], data_torch[:, 1::2, :, :]
            elif "mlp.experts.gate_up_proj_scales" in name:
                scales0, scales1 = data_torch[:, ::2, :], data_torch[:, 1::2, :]
                new_name_gate = self.map_tensor_name(name.replace("gate_up_proj_scales", "gate_proj.weight"))
                new_name_up = self.map_tensor_name(name.replace("gate_up_proj_scales", "up_proj.weight"))
                self.repack_mxfp4(new_name_gate, blocks0, scales0)
                self.repack_mxfp4(new_name_up, blocks1, scales1)
                found_mxfp4_tensors = True
        if not found_mxfp4_tensors:
            raise ValueError("No MXFP4 tensors found in the model. Please make sure you are using MXFP4 model.")
        return []

    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        del bid  # unused

        if "sinks" in name:
            name += ".weight"

        # correct naming for down_proj
        if "down_proj" in name:
            if name.endswith("_bias"):
                name = name.replace("down_proj_bias", "down_proj.bias")
            else:
                return []

        # split the gate_up into gate and up
        if "gate_up_proj" in name:
            if name.endswith("_bias"):
                name_up = name.replace("gate_up_proj_bias", "up_proj.bias")
                name_gate = name.replace("gate_up_proj_bias", "gate_proj.bias")
                gate_proj_bias, up_proj_bias = data_torch[..., ::2], data_torch[..., 1::2]
                return [
                    (self.map_tensor_name(name_gate), gate_proj_bias),
                    (self.map_tensor_name(name_up), up_proj_bias)
                ]
            else:
                return []

        return [(self.map_tensor_name(name), data_torch)]

    def set_vocab(self):
        self._set_vocab_gpt2()

    def set_gguf_parameters(self):
        super().set_gguf_parameters()
        self.gguf_writer.add_sliding_window(self.hparams["sliding_window"])
        self.gguf_writer.add_expert_feed_forward_length(self.hparams["intermediate_size"])

        rope_scaling = self.hparams.get("rope_scaling") or {}
        rope_type = rope_scaling.get("rope_type", rope_scaling.get("type"))
        assert rope_type == "yarn", f"GPT-OSS only supports yarn rope scaling, got {rope_type}"
        self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
        self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
        self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling.get("original_max_position_embeddings", 4096))


@ModelBase.register("Lfm2ForCausalLM")
@ModelBase.register("LFM2ForCausalLM")
class LFM2Model(TextModel):
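repack_mxfp4 above stores each group of 32 FP4 (E2M1) values next to its shared E8M0 scale byte, giving the 17-byte GGML MXFP4 block that the last-dim factor of 32 in new_shape accounts for. As a rough sketch of what such a block encodes, here is a NumPy dequantizer based on the OCP Microscaling value tables; the element ordering inside the 16 nibble bytes is an assumption (rearranging that order is exactly what transform_nibble_layout is for):

import numpy as np

# FP4 (E2M1) code -> value per the OCP Microscaling spec; the top bit of the
# nibble is the sign. ggml's kvalues_mxfp4 table encodes these same magnitudes.
E2M1_VALUES = np.array(
    [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
     -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0],
    dtype=np.float32,
)

def dequant_mxfp4_block(block: np.ndarray) -> np.ndarray:
    # block: 17 uint8 bytes = 1 E8M0 scale byte followed by 16 packed-nibble bytes.
    assert block.dtype == np.uint8 and block.size == 17
    scale = np.float32(2.0 ** (int(block[0]) - 127))  # E8M0 is a pure power-of-two scale (255 = NaN, not handled)
    qs = block[1:]
    codes = np.concatenate((qs & 0x0F, qs >> 4))       # assumed element order
    return E2M1_VALUES[codes] * scale

# usage: a scale byte of 127 means 2**0 = 1, so the codes map straight to E2M1 values
demo = np.concatenate((np.array([127], dtype=np.uint8), np.arange(16, dtype=np.uint8)))
print(dequant_mxfp4_block(demo))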
@@ -8089,6 +8202,7 @@ class LazyTorchTensor(gguf.LazyBase):
    _dtype_map: dict[torch.dtype, type] = {
        torch.float16: np.float16,
        torch.float32: np.float32,
        torch.uint8: np.uint8,
    }

    # used for safetensors slices
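set_gguf_parameters in the GptOssModel hunk further up insists on YARN rope scaling and copies the scaling factor and original context length into the GGUF header. A hypothetical excerpt of the hparams (config.json) shape it reads; the numbers below are placeholders for illustration, not the published gpt-oss values:

# Hypothetical excerpt of self.hparams that set_gguf_parameters expects;
# all numbers below are placeholders, not the actual gpt-oss configuration.
hparams = {
    "sliding_window": 128,
    "intermediate_size": 2880,
    "rope_scaling": {
        "rope_type": "yarn",                        # anything else trips the assert
        "factor": 32.0,
        "original_max_position_embeddings": 4096,   # default used when the key is absent
    },
}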