support qwen3, dont speak human language

This commit is contained in:
djw 2025-04-28 08:44:47 +00:00
parent f3d842a0ca
commit 3f9bbf1181
30 changed files with 3696 additions and 290 deletions

View file

@ -411,4 +411,30 @@ class RotaryEmbeddingV4(BaseInjectedModule):
self.inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
# self.register_buffer("inv_freq", inv_freq, persistent=False)
# For BC we register cos and sin cached
self.max_seq_len_cached = max_position_embeddings
self.max_seq_len_cached = max_position_embeddings
class KQwen3MoeRotaryEmbedding(BaseInjectedModule, DeepseekV2RotaryEmbedding):
def __init__(
self,
key: str,
gguf_loader: GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module,
# device: str = "cuda",
generate_device: str = "cuda",
prefill_device: str = "cuda",
**kwargs,
):
BaseInjectedModule.__init__(
self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
)
self.orig_module.__init__(
config,
)
self.generate_device = generate_device
self.prefill_device = prefill_device
def load(self):
self.orig_module.__init__(
self.orig_module.config
)