mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2026-04-28 03:39:48 +00:00
update: add cache class and ascend ln mlp op for qwen3 adapt npu (#1708)
Some checks are pending
Book-CI / test (push) Waiting to run
Book-CI / test-1 (push) Waiting to run
Book-CI / test-2 (push) Waiting to run
Deploy / deploy (macos-latest) (push) Waiting to run
Deploy / deploy (ubuntu-latest) (push) Waiting to run
Deploy / deploy (windows-latest) (push) Waiting to run
Some checks are pending
Book-CI / test (push) Waiting to run
Book-CI / test-1 (push) Waiting to run
Book-CI / test-2 (push) Waiting to run
Deploy / deploy (macos-latest) (push) Waiting to run
Deploy / deploy (ubuntu-latest) (push) Waiting to run
Deploy / deploy (windows-latest) (push) Waiting to run
This commit is contained in:
parent
cea490a326
commit
1e69563363
4 changed files with 562 additions and 1 deletions
|
|
@ -523,4 +523,392 @@ class KGQACache(nn.Module):
|
|||
return self.k_caches[layer_idx]
|
||||
|
||||
def get_v_cache(self, layer_idx):
    """Return the V-cache tensor registered for layer ``layer_idx``."""
    caches = self.v_caches
    return caches[layer_idx]
|
||||
|
||||
|
||||
class KVC2Qwen3Cache(nn.Module):
    """Paged K/V cache for Qwen3 on Ascend NPU, backed by the kvc2 pool.

    Per-layer K and V buffers are page tables of shape
    [num_pages, page_size, num_kv_heads, head_dim] (established by the
    indexing in ``update``/``_debug_verify_k_roundtrip``). ``load`` binds the
    buffers from an external inference context; ``update`` scatters new
    K/V tokens into (page_idx, page_offset) slots; ``get_page_table``
    computes those slots from position ids and block tables.

    Debug printing is gated by environment variables:
    KTRANS_DEBUG_KV_LOAD, KTRANS_DEBUG_KV_UPDATE, KTRANS_DEBUG_PAGE,
    KTRANS_DEBUG_LAYOUT, KTRANS_DEBUG_KV.
    """

    def __init__(self, config, max_batch_size, page_size=256,
                 dtype=torch.bfloat16, device=None):
        super().__init__()
        self.config = config
        self.max_batch_size = max_batch_size
        self.page_size = page_size
        self.dtype = dtype
        # Default to the first NPU when no explicit device is given.
        self.device = device if device else torch.device("npu:0")

        self.num_layers = config.num_hidden_layers
        self.num_kv_heads = config.num_key_value_heads
        self.head_dim = config.head_dim

        # Filled lazily by load(); one tensor per layer.
        self.k_caches = []
        self.v_caches = []

        # Environment variables toggle load/update debug logging.
        self.debug_load = os.environ.get("KTRANS_DEBUG_KV_LOAD", "0") == "1"
        self.debug_update = os.environ.get("KTRANS_DEBUG_KV_UPDATE", "0") == "1"

    # ------------------------- debug helpers (env-gated) -------------------------

    def _debug_dump_page_layout(
        self,
        page_idx: torch.Tensor,     # [B, Q] or [N]
        page_offset: torch.Tensor,  # same shape as page_idx
        bsz: int,
        q_len: int,
        layer_idx: int,
    ):
        """Print which pages/offsets each batch row occupies (debug only)."""
        # Skip entirely while an NPU graph capture is in progress:
        # host-side tensor reads/prints are unsafe during capture.
        try:
            if hasattr(torch.npu, "is_current_stream_capturing") and torch.npu.is_current_stream_capturing():
                return
        except Exception:
            pass

        page_size = self.page_size

        pi = page_idx.detach().to("cpu", torch.long).reshape(bsz, q_len)
        po = page_offset.detach().to("cpu", torch.long).reshape(bsz, q_len)

        for b in range(bsz):
            row_pi = pi[b]  # [Q]
            row_po = po[b]  # [Q]

            unique_pages = sorted(set(row_pi.tolist()))
            print(f"[DEBUG-LAYOUT] layer={layer_idx}, batch={b}: pages used = {unique_pages}", flush=True)

            for p in unique_pages:
                mask = (row_pi == p)
                if not mask.any():
                    continue
                offsets = row_po[mask]

                cnt = int(mask.sum())
                min_off = int(offsets.min())
                max_off = int(offsets.max())

                print(
                    f" page {p}: count={cnt}, "
                    f"offset[min,max]=[{min_off}, {max_off}]",
                    flush=True,
                )

                # For small pages/sequences, draw an ASCII occupancy bar
                # ('X' = slot written, '.' = untouched).
                if page_size <= 64 and q_len <= 64:
                    bar = ['.'] * page_size
                    for off in offsets.tolist():
                        if 0 <= off < page_size:
                            bar[off] = 'X'
                    print(f" layout: [{''.join(bar)}]", flush=True)

    def _debug_check_page_mapping(
        self,
        page_idx: torch.Tensor,
        page_offset: torch.Tensor,
        bsz: int,
        q_len: int,
        layer_idx: int,
    ):
        """
        Sanity-check the global positions implied by (page_idx, page_offset):
            global_pos = page_idx * page_size + page_offset
        Warns if the flattened count mismatches bsz*q_len or if positions
        are not non-decreasing.
        """
        page_size = self.page_size

        pi = page_idx.to(torch.long).reshape(-1)
        po = page_offset.to(torch.long).reshape(-1)
        N = bsz * q_len

        if pi.numel() != N:
            print(
                f"[DEBUG-PAGE][WARN] layer={layer_idx}: "
                f"numel(page_idx)={pi.numel()} != bsz*q_len={N}",
                flush=True,
            )
            return

        gpos = pi * page_size + po  # [N]

        # NOTE(review): n_show is the full length, so this prints every
        # position despite the "first" wording.
        n_show = gpos.numel()
        print(
            f"[DEBUG-PAGE] layer={layer_idx}, first {n_show} global_pos="
            f"{gpos[:n_show].tolist()}",
            flush=True,
        )

        if gpos.numel() > 1:
            diffs = gpos[1:] - gpos[:-1]
            print(
                f"[DEBUG-PAGE] layer={layer_idx}, "
                f"global_pos diff min={int(diffs.min())}, max={int(diffs.max())}",
                flush=True,
            )

            if not torch.all(diffs >= 0):
                print(
                    f"[DEBUG-PAGE][WARN] layer={layer_idx}: "
                    f"global_pos not non-decreasing!",
                    flush=True,
                )

    def _debug_verify_k_roundtrip(
        self,
        flat_k: torch.Tensor,  # [N, KvH, Dh]
        layer_idx: int,
        page_idx: torch.Tensor,
        page_offset: torch.Tensor,
    ):
        """Read the just-written K slots back from the cache and report the
        absolute difference versus what was written (should be ~0)."""
        k_out = self.k_caches[layer_idx]  # [num_pages, page_size, KvH, Dh]

        pi = page_idx.to(torch.long).reshape(-1)
        po = page_offset.to(torch.long).reshape(-1)

        fetched = k_out[pi, po]  # [N, KvH, Dh]

        if fetched.shape != flat_k.shape:
            print(
                f"[DEBUG-KV][WARN] layer={layer_idx}: "
                f"fetched.shape={tuple(fetched.shape)} != flat_k.shape={tuple(flat_k.shape)}",
                flush=True,
            )
            return

        diff = (fetched - flat_k).abs()
        max_diff = diff.max().item()
        mean_diff = diff.mean().item()

        print(
            f"[DEBUG-KV] layer={layer_idx}: K roundtrip max_abs_diff={max_diff}, "
            f"mean_abs_diff={mean_diff}",
            flush=True,
        )

    # ------------------------- bind to the underlying kvc2 pool -------------------------

    def load(self, inference_context):
        """Bind per-layer K/V buffers from ``inference_context`` onto the
        current device, mark them static for dynamo/graph capture, and
        record the total cache capacity in tokens."""
        from ktransformers.util.utils import get_current_device
        dev = get_current_device()

        self.k_caches = []
        self.v_caches = []

        rank = (
            torch.distributed.get_rank()
            if (torch.distributed.is_available() and torch.distributed.is_initialized())
            else 0
        )

        for i in range(self.num_layers):
            k_buf = inference_context.k_cache[rank][i].to(dev).to(self.dtype)
            v_buf = inference_context.v_cache[rank][i].to(dev).to(self.dtype)

            # Pin the buffer addresses so graph capture / dynamo can assume
            # they never move between replays.
            torch._dynamo.mark_static_address(k_buf)
            torch._dynamo.mark_static_address(v_buf)

            self.k_caches.append(k_buf)
            self.v_caches.append(v_buf)

            if self.debug_load:
                print(
                    f"[KV-CACHE-LOAD] layer={i}, "
                    f"k_cache shape={tuple(k_buf.shape)}, "
                    f"v_cache shape={tuple(v_buf.shape)}, dtype={k_buf.dtype}",
                    flush=True,
                )

        # Total capacity in tokens: num_pages * page_size.
        self.max_cache_len = self.k_caches[0].shape[0] * self.k_caches[0].shape[1]

    # ------------------------- write KV -------------------------
    @torch.no_grad()
    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """Scatter new K/V token states into the paged cache for one layer.

        ``cache_kwargs`` must carry ``page_idx`` and ``page_offset`` tensors
        addressing one cache slot per token. Accepts key/value states in
        either [B, Q, KvH, D] or [B, KvH, Q, D] layout (the latter is
        transposed). Returns the layer's full (k_cache, v_cache) buffers.
        """
        if cache_kwargs is None:
            raise ValueError("[KVC2Qwen3Cache] cache_kwargs must contain page_idx & page_offset")

        page_idx: Optional[torch.Tensor] = cache_kwargs.get("page_idx", None)
        page_offset: Optional[torch.Tensor] = cache_kwargs.get("page_offset", None)

        if page_idx is None or page_offset is None:
            raise ValueError("[KVC2Qwen3Cache] page_idx & page_offset are required in cache_kwargs")

        k_out = self.k_caches[layer_idx]
        v_out = self.v_caches[layer_idx]

        if self.debug_update:
            print(
                "[KV-UPDATE]",
                f"layer={layer_idx}, key={tuple(key_states.shape)}, value={tuple(value_states.shape)}, "
                f"page_idx shape={tuple(page_idx.shape)}, page_offset shape={tuple(page_offset.shape)}, "
                f"k_out shape={tuple(k_out.shape)}, k_out.dtype={k_out.dtype}",
                flush=True,
            )

        # -------- 1) Fix layout: [B, KvH, Q, D] -> [B, Q, KvH, D] --------
        # Heuristic: dim 1 equal to num_kv_heads means heads-first layout.
        if key_states.dim() == 4 and key_states.shape[1] == self.num_kv_heads:
            if self.debug_update:
                print(
                    "[KV-UPDATE] detected layout [B, KvH, Q, D], transpose -> [B, Q, KvH, D]",
                    flush=True,
                )
            key_states = key_states.transpose(1, 2).contiguous()
            value_states = value_states.transpose(1, 2).contiguous()

        if key_states.shape != value_states.shape:
            raise ValueError(
                f"[KVC2Qwen3Cache] key_states.shape {key_states.shape} "
                f"!= value_states.shape {value_states.shape}"
            )

        if key_states.dim() != 4:
            raise ValueError(
                f"[KVC2Qwen3Cache] expect key_states dim=4, got {key_states.dim()} "
                f"(shape={key_states.shape})"
            )

        bsz, q_len, kv_heads, head_dim = key_states.shape

        if kv_heads != self.num_kv_heads or head_dim != self.head_dim:
            raise ValueError(
                f"[KVC2Qwen3Cache] KV shape mismatch: "
                f"got num_kv_heads={kv_heads}, head_dim={head_dim}, "
                f"expected num_kv_heads={self.num_kv_heads}, head_dim={self.head_dim}"
            )

        # ================== DEBUG: verify the page mapping ==================
        if os.environ.get("KTRANS_DEBUG_PAGE", "0") == "1":
            try:
                # Host-side debug reads are skipped during graph capture.
                if not torch.npu.is_current_stream_capturing():
                    self._debug_check_page_mapping(
                        page_idx,
                        page_offset,
                        bsz=bsz,
                        q_len=q_len,
                        layer_idx=layer_idx,
                    )
            except Exception:
                pass

        if os.environ.get("KTRANS_DEBUG_LAYOUT", "0") == "1":
            self._debug_dump_page_layout(
                page_idx,
                page_offset,
                bsz=bsz,
                q_len=q_len,
                layer_idx=layer_idx,
            )

        # -------- 2) Flatten page_idx / page_offset to 1-D --------
        page_idx = page_idx.reshape(-1)
        page_offset = page_offset.reshape(-1)

        # -------- 3) Flatten KV and force dtype to match the cache --------
        val_dtype = k_out.dtype
        flat_k = key_states.to(val_dtype).reshape(-1, kv_heads, head_dim)
        flat_v = value_states.to(val_dtype).reshape(-1, kv_heads, head_dim)

        # -------- 4) Actual K / V write --------
        # k_out / v_out: [num_pages, page_size, num_kv_heads, head_dim]
        k_out[page_idx, page_offset] = flat_k
        v_out[page_idx, page_offset] = flat_v

        # ================== DEBUG: read back and compare ==================
        if os.environ.get("KTRANS_DEBUG_KV", "0") == "1":
            try:
                if not torch.npu.is_current_stream_capturing():
                    self._debug_verify_k_roundtrip(
                        flat_k=flat_k,
                        layer_idx=layer_idx,
                        page_idx=page_idx,
                        page_offset=page_offset,
                    )
            except Exception:
                pass

        return k_out, v_out

    # ------------------------- get K/V -------------------------
    def get_k_cache(self, layer_idx):
        # Full paged K buffer for one layer.
        return self.k_caches[layer_idx]

    def get_v_cache(self, layer_idx):
        # Full paged V buffer for one layer.
        return self.v_caches[layer_idx]

    # ------------------------- page table computation -------------------------
    def get_page_table(
        self,
        mini_batch,
        bsz_tensors: torch.Tensor = None,
        is_prefill: bool = True,
    ):
        """Translate token positions into (page_idx, page_offset) pairs.

        Prefill: builds [prefill_batch, max_q_len] tensors padded with -1,
        mapping each chunk position through p_block_tables.
        Decode: one position per sample, mapped through d_block_tables.
        Returns (None, None) when the prefill batch is empty.
        """
        if is_prefill:
            # prefill: merged positions => batched (B, T_chunk)
            q_lens = [int(mini_batch.p_q_len[idx]) for idx in range(mini_batch.prefill_batch)]
            if len(q_lens) == 0:
                return None, None

            max_q_len = max(q_lens)

            # -1 marks padding slots for shorter chunks.
            page_local_idx = -1 * torch.ones(
                mini_batch.prefill_batch,
                max_q_len,
                dtype=mini_batch.p_position_ids.dtype,
                device=mini_batch.p_position_ids.device,
            )
            page_offset = -1 * torch.ones_like(page_local_idx)

            start_ids = 0
            for i in range(mini_batch.prefill_batch):
                cur_len = q_lens[i]
                pos = mini_batch.p_position_ids[start_ids:start_ids + cur_len]  # global pos of this chunk

                # local block + offset by page_size
                page_offset[i, 0:cur_len] = pos % self.page_size
                page_local_idx[i, 0:cur_len] = pos // self.page_size

                # local block -> global page id via block_tables
                for j in range(cur_len):
                    blk = page_local_idx[i, j]
                    page_local_idx[i, j] = mini_batch.p_block_tables[i, blk]

                start_ids += cur_len

            page_idx = page_local_idx
        else:
            # decode: decode_batch is the current step's batch size; each
            # sample typically contributes a single token.
            page_local_idx = mini_batch.d_position_ids // self.page_size
            page_offset = mini_batch.d_position_ids % self.page_size

            for i in range(mini_batch.decode_batch):
                blk = page_local_idx[i]
                page_local_idx[i] = mini_batch.d_block_tables[i, blk]

            page_idx = page_local_idx

        return page_idx, page_offset
||||
|
|
|
|||
|
|
@ -1,9 +1,30 @@
|
|||
# coding=utf-8
|
||||
# Copyright (c) 2025. Huawei Technologies Co., Ltd. All rights reserved.
|
||||
# Copyright 2025 The ZhipuAI Inc. team and HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import re
|
||||
from typing import Optional, Union, Tuple
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from ktransformers.operators.base_operator import BaseInjectedModule
|
||||
from ktransformers.models.modeling_qwen3_moe import Qwen3MoeRMSNorm
|
||||
from ktransformers.util import utils
|
||||
from ktransformers.util.custom_loader import GGUFLoader
|
||||
|
||||
|
|
@ -36,3 +57,79 @@ class KDeepseekV3RMSNormW8A8(BaseInjectedModule):
|
|||
self.weight = None
|
||||
if self.bias is not None:
|
||||
self.bias = None
|
||||
|
||||
class KQwen3MoeRMSNormW8A8(BaseInjectedModule):
    """RMSNorm for Qwen3-MoE layers on Ascend NPU (W8A8 pipeline).

    Wraps the original Qwen3MoeRMSNorm and dispatches the normalization to
    the fused ``torch_npu.npu_rms_norm`` kernel in float16, restoring the
    caller's dtype on the way out.
    """

    def __init__(self,
                 key: str,
                 gguf_loader: GGUFLoader,
                 config: PretrainedConfig,
                 orig_module: nn.Module,
                 prefill_device: str = "npu",
                 generate_device: str = "npu",
                 **kwargs):

        super().__init__(key, gguf_loader, config, orig_module,
                         prefill_device, generate_device, **kwargs)

        self.hidden_size = orig_module.hidden_size
        self.variance_epsilon = orig_module.variance_epsilon
        self.weight = nn.Parameter(orig_module.weight.data.clone())

    def forward(self, x: torch.Tensor):
        """Apply RMSNorm via npu_rms_norm; output dtype matches the input.

        Bug fix: the caller's dtype is captured BEFORE the float16 cast.
        Previously ``input_dtype`` was read after ``x.to(torch.float16)``,
        so the final ``out.to(input_dtype)`` was a no-op and the result was
        always float16 regardless of the input dtype.
        """
        input_dtype = x.dtype
        x = x.to(torch.float16)
        gamma = self.weight.to(torch.float16)

        # npu_rms_norm returns a tuple; [0] is the normalized output.
        out = torch_npu.npu_rms_norm(
            x,
            gamma,
            self.variance_epsilon
        )[0]

        return out.to(input_dtype)

    def load(self):
        # Load the norm weight (and optional bias) for this key onto the
        # current device via the safetensor loader.
        device = utils.get_current_device()
        self.weight = self.gguf_loader.safetensor_loader.load_tensor(self.key + ".weight").to(device)

        try:
            self.bias = (
                self.gguf_loader.safetensor_loader
                .load_tensor(self.key + ".bias")
                .to(device)
            )
        except KeyError:
            # RMSNorm checkpoints usually carry no bias tensor.
            self.bias = None

    def unload(self):
        # Drop references so device memory can be reclaimed.
        self.weight = None
        self.bias = None
|
||||
|
||||
class KQwen3FinalRMSNormNPU(nn.Module):
    """Final RMSNorm delegating to the torch_npu fused kernel.

    The wrapped module's weight is kept as-is when it is already float16 or
    bfloat16; any other dtype (float32 included) is converted to float16.
    """

    def __init__(self, orig_module: nn.Module):
        super().__init__()
        assert hasattr(orig_module, "weight"), "orig_module must have weight"
        # Fall back to 1e-6 when the wrapped module defines no epsilon.
        self.variance_epsilon = getattr(orig_module, "variance_epsilon", 1e-6)

        gamma = orig_module.weight.detach()
        # Single condition equivalent to the original two-branch dtype
        # handling: only fp16/bf16 survive unchanged.
        if gamma.dtype not in (torch.float16, torch.bfloat16):
            gamma = gamma.to(torch.float16)

        self.weight = nn.Parameter(gamma)

    def forward(self, x: torch.Tensor):
        """Normalize ``x`` with npu_rms_norm and cast back to its dtype."""
        input_dtype = x.dtype
        gamma = self.weight
        normed = torch_npu.npu_rms_norm(
            x.contiguous().to(dtype=gamma.dtype),
            gamma,
            self.variance_epsilon,
        )[0]
        return normed.to(input_dtype)
|
||||
|
|
@ -1,3 +1,19 @@
|
|||
# coding=utf-8
|
||||
# Copyright (c) 2025. Huawei Technologies Co., Ltd. All rights reserved.
|
||||
# Copyright 2025 The ZhipuAI Inc. team and HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from abc import abstractmethod
|
||||
|
||||
import torch
|
||||
|
|
@ -113,6 +129,8 @@ class KLinearTorchW8A8A2(KLinearW8A8):
|
|||
self.weight_offset = None
|
||||
|
||||
def forward(self, x: torch.Tensor, bsz_tensor) -> torch.Tensor:
    """Plain-torch matmul path: align the input dtype with the weight,
    then multiply. ``bsz_tensor`` is accepted for interface parity but
    unused here."""
    weight = self.weight
    if x.dtype != weight.dtype:
        x = x.to(weight.dtype)
    return torch.matmul(x, weight)
|
||||
|
||||
def load(self, w: dict | nn.Parameter | tuple | None = None, device: str | None = None):
|
||||
|
|
|
|||
|
|
@ -1,3 +1,19 @@
|
|||
# coding=utf-8
|
||||
# Copyright (c) 2025. Huawei Technologies Co., Ltd. All rights reserved.
|
||||
# Copyright 2025 The ZhipuAI Inc. team and HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
|
||||
|
|
@ -5,6 +21,7 @@ from ktransformers.util.ascend.ascend_utils import allredeuce_warpper
|
|||
from ktransformers.util.utils import CUR_DEVICE
|
||||
from ktransformers.operators.base_operator import BaseInjectedModule
|
||||
from ktransformers.models.modeling_deepseek_v3 import DeepseekV3MLP
|
||||
from ktransformers.models.modeling_qwen3_moe import Qwen3MoeMLP
|
||||
|
||||
class KDeepseekV3MLPW8A8A2V1(BaseInjectedModule, DeepseekV3MLP):
|
||||
@allredeuce_warpper
|
||||
|
|
@ -71,3 +88,44 @@ class KDeepseekV3MLPW8A8A2V2(BaseInjectedModule, DeepseekV3MLP):
|
|||
)
|
||||
down_proj = down_proj.reshape(x.shape)
|
||||
return down_proj
|
||||
|
||||
class KQwen3MoeMLPW8A8A2(BaseInjectedModule, Qwen3MoeMLP):
    """Qwen3-MoE MLP with W8A8 dynamic quantization on Ascend NPU.

    forward() quantizes activations per token (npu_dynamic_quant), runs the
    gate/up/down projections as quantized matmuls (npu_quant_matmul) using
    the pre-quantized weights and scales held by ``orig_module``, and applies
    SiLU gating between them in the original activation dtype.
    """

    @allredeuce_warpper
    def forward(self, x, is_prefill=None, use_cuda_graph=False):
        # Remember the activation dtype so the quantized matmuls can emit it
        # directly (output_dtype below).
        original_dtype = x.dtype
        # Per-token dynamic quantization: int8 activations + per-token scale.
        quant_out, dynamic_scale = torch_npu.npu_dynamic_quant(x)
        dynamic_scale = dynamic_scale.view(-1)
        quant_out = quant_out.view(-1, quant_out.shape[-1])

        # Gate and up projections share the same quantized input.
        gate_x = torch_npu.npu_quant_matmul(
            quant_out,
            self.orig_module.gate_proj.weight,
            self.orig_module.gate_proj.weight_scale,
            pertoken_scale=dynamic_scale,
            bias=None,
            output_dtype=original_dtype,
        )
        up_x = torch_npu.npu_quant_matmul(
            quant_out,
            self.orig_module.up_proj.weight,
            self.orig_module.up_proj.weight_scale,
            pertoken_scale=dynamic_scale,
            bias=None,
            output_dtype=original_dtype,
        )

        # SiLU gating in the original dtype before re-quantizing for down_proj.
        down_x = torch.nn.functional.silu(gate_x) * up_x

        down_quant_out, down_dynamic_scale = torch_npu.npu_dynamic_quant(down_x)
        down_dynamic_scale = down_dynamic_scale.view(-1)

        down_proj = torch_npu.npu_quant_matmul(
            down_quant_out,
            self.orig_module.down_proj.weight,
            self.orig_module.down_proj.weight_scale,
            pertoken_scale=down_dynamic_scale,
            bias=None,
            output_dtype=original_dtype,
        )
        # Restore the caller's shape (tokens were flattened for the matmuls).
        down_proj = down_proj.reshape(x.shape)
        return down_proj
||||
Loading…
Add table
Add a link
Reference in a new issue