llama : add gpt-oss (#15091)

* oai moe

* compat with new checkpoint

* add attn sink impl

* add rope scaling yarn

* logits match with latest transformers code

* wip chat template

* rm trailing space

* use ggml_scale_bias

* rm redundant is_swa_all

* convert interleaved gate_up

* graph : fix activation function to match reference (#7)

* vocab : handle o200k_harmony special tokens

* ggml : add attention sinks support (#1)

* llama : add attn sinks

* ggml : add attn sinks

* cuda : add attn sinks

* vulkan : add support for sinks in softmax

remove unnecessary return
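
Note: as a rough illustration of what the sink support changes in the softmax (not the ggml API; names and shapes are made up), each head carries one learned sink logit that enters the normalization but contributes no value, so it only absorbs probability mass:

import numpy as np

def softmax_with_sink(scores: np.ndarray, sink: float) -> np.ndarray:
    # scores: (n_kv,) attention logits for one query of one head
    # sink: learned per-head sink logit
    m = max(float(scores.max()), sink)        # subtract the max for numerical stability
    e = np.exp(scores - m)
    denom = e.sum() + np.exp(sink - m)        # the sink only joins the denominator
    return e / denom                          # weights sum to < 1; the remainder "leaks" into the sink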

* ggml : add fused swiglu_oai op (#11)

* ggml : add fused swiglu_oai op

* Update ggml/src/ggml-cpu/ops.cpp

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* update CUDA impl

* cont : metal impl

* add vulkan impl

* test-backend-ops : more test cases, clean up

* llama : remove unfused impl

* remove extra lines

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

---------

Co-authored-by: slaren <slarengh@gmail.com>
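
Note: the fused swiglu_oai op corresponds to the clamped SwiGLU variant from the gpt-oss reference code; a PyTorch sketch of the math (the alpha/limit defaults follow that reference and should be treated as assumptions for other checkpoints):

import torch

def swiglu_oai(gate: torch.Tensor, up: torch.Tensor,
               alpha: float = 1.702, limit: float = 7.0) -> torch.Tensor:
    gate = gate.clamp(max=limit)              # clamp the gate branch only from above
    up = up.clamp(min=-limit, max=limit)      # clamp the linear branch on both sides
    glu = gate * torch.sigmoid(alpha * gate)  # sigmoid-weighted gate (alpha=1.702 approximates GELU)
    return (up + 1.0) * glu                   # note the +1 on the linear branch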

* repack mxfp4 upon conversion

* clean up a bit

* enable thinking

* add quick hack to render only some special tokens

* fix bf16 conversion

* remove vocab hack

* webui ok

* support chat parsing for gpt-oss

* fix webui

* direct mapping mxfp4, FINALLY

* force using mxfp4

* properly use lazy tensor

* ggml : add mxfp4

ggml : use e8m0 conversion instead of powf

Co-authored-by: Diego Devesa <slarengh@gmail.com>

change kvalues_mxfp4 table to match e2m1 (#6)

metal : remove quantization for now (not used)

cuda : fix disabled CUDA graphs due to ffn moe bias

vulkan : add support for mxfp4

cont : add cm2 dequant
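
Note: MXFP4 packs blocks of 32 FP4 (E2M1) values behind one shared E8M0 scale, i.e. an 8-bit power-of-two exponent with bias 127, which is why the scale can be applied via an e8m0 bit conversion instead of powf. A minimal dequantization sketch for one repacked block (1 scale byte + 16 nibble bytes; the low/high-nibble ordering mirrors ggml's q4_0-style layout and is an assumption here):

import numpy as np

# E2M1 code points per the OCP microscaling (MX) spec
E2M1 = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
                 -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0], dtype=np.float32)

def dequant_mxfp4_block(block: np.ndarray) -> np.ndarray:
    # block: 17 uint8 bytes = 1 E8M0 scale byte followed by 16 packed FP4 bytes (32 values)
    assert block.dtype == np.uint8 and block.size == 17
    d = np.float32(2.0 ** (int(block[0]) - 127))  # E8M0 scale: 2^(e - 127)
    qs = block[1:]
    lo = E2M1[qs & 0x0F]   # elements 0..15 from the low nibbles (assumed layout)
    hi = E2M1[qs >> 4]     # elements 16..31 from the high nibbles (assumed layout)
    return d * np.concatenate([lo, hi])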

* ggml : add ggml_add_id (#13)

* ggml : add ggml_add_id

* add cuda impl

* llama : add weight support check for add_id

* perf opt

* add vulkan impl

* rename cuda files

* add metal impl

* allow in-place ggml_add_id
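
Note: ggml_add_id is introduced for the per-expert FFN biases: for every activation row it adds the bias row selected by that row's expert id. A simplified NumPy sketch of the semantics (real ggml tensors carry an extra experts-per-token dimension; names here are illustrative):

import numpy as np

def add_id(a: np.ndarray, b: np.ndarray, ids: np.ndarray) -> np.ndarray:
    # a: (n_tokens, n_embd) activations, b: (n_expert, n_embd) per-expert biases,
    # ids: (n_tokens,) expert chosen for each row -> returns a + b[ids]
    return a + b[ids]

# example: 2 experts, 3 rows routed to experts [1, 0, 1]
a = np.zeros((3, 4), dtype=np.float32)
b = np.array([[1, 1, 1, 1], [2, 2, 2, 2]], dtype=np.float32)
print(add_id(a, b, np.array([1, 0, 1])))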

* llama : keep biases on CPU with --cpu-moe

* llama : fix compile error

ggml-ci

* cuda : add fallback for __nv_cvt_e8m0_to_bf16raw

ggml-ci

* cleanup

ggml-ci

* sycl : fix supports_op for MXFP4

ggml-ci

* fix "Unknown reasoning format" error

* ggml-cpu : fix AVX build

ggml-ci

* fix hip build

ggml-ci

* cuda : add mxfp4 dequantization support for cuBLAS

ggml-ci

* ggml-cpu : fix mxfp4 fallback definitions for some architectures

ggml-ci

* cuda : fix version required for __nv_cvt_e8m0_to_bf16raw

---------

Co-authored-by: Xuan Son Nguyen <son@huggingface.co>
Co-authored-by: slaren <slarengh@gmail.com>

commit fd1234cb46 (parent f324a3b715)
Georgi Gerganov, 2025-08-05 22:10:36 +03:00, committed by GitHub
83 changed files with 2942 additions and 227 deletions


@@ -7950,6 +7950,119 @@ class SmolLM3Model(LlamaModel):
            self.gguf_writer.add_chat_template(chat_template)

@ModelBase.register("GptOssForCausalLM")
class GptOssModel(TextModel):
model_arch = gguf.MODEL_ARCH.GPT_OSS
def transform_nibble_layout(self, tensor):
assert tensor.dtype == torch.uint8
assert tensor.shape[-1] == 16
# swap nibbles
t_lo = tensor & 0x0F
t_hi = tensor & 0xF0
t_swapped = (t_lo << 4) | (t_hi >> 4)
tensor = t_swapped
# transform aaaa...bbbb... to abababab...
blk_a, blk_b = tensor.chunk(2, dim=-1)
# get a_
blk_a0 = (blk_a & 0xF0).view(-1, 1)
blk_a1 = (blk_a << 4).view(-1, 1)
blk_a = torch.stack((blk_a0, blk_a1), dim=2).view(tensor.shape)
# get _b
blk_b0 = (blk_b >> 4).view(-1, 1)
blk_b1 = (blk_b & 0x0F).view(-1, 1)
blk_b = torch.stack((blk_b0, blk_b1), dim=2).view(tensor.shape)
# swap once more
out = blk_a | blk_b
out_h = out & 0xF0
out_l = out & 0x0F
out = (out_h >> 4) | (out_l << 4)
return out

    def repack_mxfp4(self, new_name: str, blocks: Tensor, scales: Tensor):
        assert blocks.dtype == torch.uint8
        assert scales.dtype == torch.uint8
        scales = scales.unsqueeze(-1)
        assert len(blocks.shape) == 4
        assert len(scales.shape) == 4
        blocks = self.transform_nibble_layout(blocks)
        new_data = torch.concat((scales, blocks), dim=-1)
        new_shape = [new_data.shape[0], new_data.shape[1], new_data.shape[2] * 32]
        logger.info(f"Repacked {new_name} with shape {new_shape} and quantization MXFP4")
        # flatten last dim
        new_data = new_data.view(new_data.shape[0], new_data.shape[1], new_data.shape[2] * new_data.shape[3])
        new_data = new_data.numpy()
        self.gguf_writer.add_tensor(new_name, new_data, raw_dtype=gguf.GGMLQuantizationType.MXFP4)

    def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
        blocks0: Tensor = torch.zeros(1)
        blocks1: Tensor = torch.zeros(1)
        found_mxfp4_tensors = False
        # we assume that tensors are loaded in the correct order
        for name, data_torch in self.get_tensors():
            if "mlp.experts.down_proj_blocks" in name:
                blocks0 = data_torch
            elif "mlp.experts.down_proj_scales" in name:
                new_name = self.map_tensor_name(name.replace("_scales", ".weight"))
                self.repack_mxfp4(new_name, blocks0, data_torch)
                found_mxfp4_tensors = True
            elif "mlp.experts.gate_up_proj_blocks" in name:
                blocks0, blocks1 = data_torch[:, ::2, :, :], data_torch[:, 1::2, :, :]
            elif "mlp.experts.gate_up_proj_scales" in name:
                scales0, scales1 = data_torch[:, ::2, :], data_torch[:, 1::2, :]
                new_name_gate = self.map_tensor_name(name.replace("gate_up_proj_scales", "gate_proj.weight"))
                new_name_up = self.map_tensor_name(name.replace("gate_up_proj_scales", "up_proj.weight"))
                self.repack_mxfp4(new_name_gate, blocks0, scales0)
                self.repack_mxfp4(new_name_up, blocks1, scales1)
                found_mxfp4_tensors = True
        if not found_mxfp4_tensors:
            raise ValueError("No MXFP4 tensors found in the model. Please make sure you are using an MXFP4 model.")
        return []

    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        del bid  # unused

        if "sinks" in name:
            name += ".weight"

        # correct naming for down_proj
        if "down_proj" in name:
            if name.endswith("_bias"):
                name = name.replace("down_proj_bias", "down_proj.bias")
            else:
                return []

        # split the gate_up into gate and up
        if "gate_up_proj" in name:
            if name.endswith("_bias"):
                name_up = name.replace("gate_up_proj_bias", "up_proj.bias")
                name_gate = name.replace("gate_up_proj_bias", "gate_proj.bias")
                gate_proj_bias, up_proj_bias = data_torch[..., ::2], data_torch[..., 1::2]
                return [
                    (self.map_tensor_name(name_gate), gate_proj_bias),
                    (self.map_tensor_name(name_up), up_proj_bias)
                ]
            else:
                return []

        return [(self.map_tensor_name(name), data_torch)]

    def set_vocab(self):
        self._set_vocab_gpt2()

    def set_gguf_parameters(self):
        super().set_gguf_parameters()
        self.gguf_writer.add_sliding_window(self.hparams["sliding_window"])
        self.gguf_writer.add_expert_feed_forward_length(self.hparams["intermediate_size"])

        rope_scaling = self.hparams.get("rope_scaling") or {}
        rope_type = rope_scaling.get("rope_type", rope_scaling.get("type"))
        assert rope_type == "yarn", f"GPT-OSS only supports yarn rope scaling, got {rope_type}"
        self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
        self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
        self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling.get("original_max_position_embeddings", 4096))

@ModelBase.register("Lfm2ForCausalLM")
@ModelBase.register("LFM2ForCausalLM")
class LFM2Model(TextModel):
@@ -8089,6 +8202,7 @@ class LazyTorchTensor(gguf.LazyBase):
    _dtype_map: dict[torch.dtype, type] = {
        torch.float16: np.float16,
        torch.float32: np.float32,
        torch.uint8: np.uint8,
    }

    # used for safetensors slices