[feat]: add mistral moe loader compatibility (#1873)
Some checks failed
Book-CI / test-1 (push) Has been cancelled
Book-CI / test-2 (push) Has been cancelled
Deploy / deploy (macos-latest) (push) Has been cancelled
Book-CI / test (push) Has been cancelled
Deploy / deploy (ubuntu-latest) (push) Has been cancelled
Deploy / deploy (windows-latest) (push) Has been cancelled

Co-authored-by: chenht2022 <chenht2022@users.noreply.github.com>
This commit is contained in:
Chen Hongtao 2026-02-28 17:50:23 +08:00 committed by GitHub
parent 19887e4363
commit 9e69fccb02
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 73 additions and 19 deletions

View file

@ -448,6 +448,10 @@ class NativeMoEWrapper(BaseMoEWrapper):
self.down_scales = [t.to(torch.float32).contiguous() for t in weights["down_scale"]]
assert self.gate_scales[0].dtype == torch.float32, "Expected float32 scales for FP8"
elif self.method == "FP8_PERCHANNEL":
if self.gate_scales[0].dtype != torch.float32:
self.gate_scales = [t.to(torch.float32).contiguous() for t in weights["gate_scale"]]
self.up_scales = [t.to(torch.float32).contiguous() for t in weights["up_scale"]]
self.down_scales = [t.to(torch.float32).contiguous() for t in weights["down_scale"]]
assert self.gate_scales[0].dtype == torch.float32, "Expected float32 scales for FP8_PERCHANNEL"
t2 = time.time()