update: add cache class and Ascend LN/MLP ops for Qwen3 NPU adaptation (#1708)

Author: Shaoxu Cheng, 2025-12-11 17:08:35 +08:00 (committed by GitHub)
parent cea490a326
commit 1e69563363
4 changed files with 562 additions and 1 deletion


@@ -523,4 +523,392 @@ class KGQACache(nn.Module):
return self.k_caches[layer_idx]
def get_v_cache(self, layer_idx):
return self.v_caches[layer_idx]
class KVC2Qwen3Cache(nn.Module):
def __init__(self, config, max_batch_size, page_size=256,
dtype=torch.bfloat16, device=None):
super().__init__()
self.config = config
self.max_batch_size = max_batch_size
self.page_size = page_size
self.dtype = dtype
self.device = device if device else torch.device("npu:0")
self.num_layers = config.num_hidden_layers
self.num_kv_heads = config.num_key_value_heads
self.head_dim = config.head_dim
self.k_caches = []
self.v_caches = []
# Environment variables controlling logging/debugging
self.debug_load = os.environ.get("KTRANS_DEBUG_KV_LOAD", "0") == "1"
self.debug_update = os.environ.get("KTRANS_DEBUG_KV_UPDATE", "0") == "1"
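# Further switches consulted in update() below: KTRANS_DEBUG_PAGE,
# KTRANS_DEBUG_LAYOUT and KTRANS_DEBUG_KV.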
# ------------------------- Debug helpers (gated by env vars) -------------------------
def _debug_dump_page_layout(
self,
page_idx: torch.Tensor, # [B, Q] or [N]
page_offset: torch.Tensor, # same shape as page_idx
bsz: int,
q_len: int,
layer_idx: int,
):
# Skip while a graph is being captured
try:
if hasattr(torch.npu, "is_current_stream_capturing") and torch.npu.is_current_stream_capturing():
return
except Exception:
pass
page_size = self.page_size
pi = page_idx.detach().to("cpu", torch.long).reshape(bsz, q_len)
po = page_offset.detach().to("cpu", torch.long).reshape(bsz, q_len)
for b in range(bsz):
row_pi = pi[b] # [Q]
row_po = po[b] # [Q]
unique_pages = sorted(set(row_pi.tolist()))
print(f"[DEBUG-LAYOUT] layer={layer_idx}, batch={b}: pages used = {unique_pages}", flush=True)
for p in unique_pages:
mask = (row_pi == p)
if not mask.any():
continue
offsets = row_po[mask]
cnt = int(mask.sum())
min_off = int(offsets.min())
max_off = int(offsets.max())
print(
f" page {p}: count={cnt}, "
f"offset[min,max]=[{min_off}, {max_off}]",
flush=True,
)
# When Q is small, draw an ASCII bar chart (only the first 64 positions)
if page_size <= 64 and q_len <= 64:
bar = ['.'] * page_size
for off in offsets.tolist():
if 0 <= off < page_size:
bar[off] = 'X'
print(f" layout: [{''.join(bar)}]", flush=True)
def _debug_check_page_mapping(
self,
page_idx: torch.Tensor,
page_offset: torch.Tensor,
bsz: int,
q_len: int,
layer_idx: int,
):
"""
(page_idx, page_offset) 展开后的 global_pos 是否合理
global_pos = page_idx * page_size + page_offset
"""
page_size = self.page_size
pi = page_idx.to(torch.long).reshape(-1)
po = page_offset.to(torch.long).reshape(-1)
N = bsz * q_len
if pi.numel() != N:
print(
f"[DEBUG-PAGE][WARN] layer={layer_idx}: "
f"numel(page_idx)={pi.numel()} != bsz*q_len={N}",
flush=True,
)
return
gpos = pi * page_size + po # [N]
n_show = gpos.numel()
print(
f"[DEBUG-PAGE] layer={layer_idx}, all {n_show} global_pos="
f"{gpos[:n_show].tolist()}",
flush=True,
)
if gpos.numel() > 1:
diffs = gpos[1:] - gpos[:-1]
print(
f"[DEBUG-PAGE] layer={layer_idx}, "
f"global_pos diff min={int(diffs.min())}, max={int(diffs.max())}",
flush=True,
)
if not torch.all(diffs >= 0):
print(
f"[DEBUG-PAGE][WARN] layer={layer_idx}: "
f"global_pos not non-decreasing!",
flush=True,
)
def _debug_verify_k_roundtrip(
self,
flat_k: torch.Tensor, # [N, KvH, Dh]
layer_idx: int,
page_idx: torch.Tensor,
page_offset: torch.Tensor,
):
k_out = self.k_caches[layer_idx] # [num_pages, page_size, KvH, Dh]
pi = page_idx.to(torch.long).reshape(-1)
po = page_offset.to(torch.long).reshape(-1)
fetched = k_out[pi, po] # [N, KvH, Dh]
if fetched.shape != flat_k.shape:
print(
f"[DEBUG-KV][WARN] layer={layer_idx}: "
f"fetched.shape={tuple(fetched.shape)} != flat_k.shape={tuple(flat_k.shape)}",
flush=True,
)
return
diff = (fetched - flat_k).abs()
max_diff = diff.max().item()
mean_diff = diff.mean().item()
print(
f"[DEBUG-KV] layer={layer_idx}: K roundtrip max_abs_diff={max_diff}, "
f"mean_abs_diff={mean_diff}",
flush=True,
)
# ------------------------- Bind to the underlying kvc2 pool -------------------------
def load(self, inference_context):
from ktransformers.util.utils import get_current_device
dev = get_current_device()
self.k_caches = []
self.v_caches = []
rank = (
torch.distributed.get_rank()
if (torch.distributed.is_available() and torch.distributed.is_initialized())
else 0
)
for i in range(self.num_layers):
k_buf = inference_context.k_cache[rank][i].to(dev).to(self.dtype)
v_buf = inference_context.v_cache[rank][i].to(dev).to(self.dtype)
torch._dynamo.mark_static_address(k_buf)
torch._dynamo.mark_static_address(v_buf)
self.k_caches.append(k_buf)
self.v_caches.append(v_buf)
if self.debug_load:
print(
f"[KV-CACHE-LOAD] layer={i}, "
f"k_cache shape={tuple(k_buf.shape)}, "
f"v_cache shape={tuple(v_buf.shape)}, dtype={k_buf.dtype}",
flush=True,
)
# num_pages * page_size
self.max_cache_len = self.k_caches[0].shape[0] * self.k_caches[0].shape[1]
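# Worked example (hypothetical sizes): k_caches[0] of shape
# [num_pages=128, page_size=256, KvH, Dh] gives
# max_cache_len = 128 * 256 = 32768 cacheable positions.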
# ------------------------- Write KV -------------------------
@torch.no_grad()
def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: Optional[Dict[str, Any]] = None,
):
if cache_kwargs is None:
raise ValueError("[KVC2Qwen3Cache] cache_kwargs must contain page_idx & page_offset")
page_idx: Optional[torch.Tensor] = cache_kwargs.get("page_idx", None)
page_offset: Optional[torch.Tensor] = cache_kwargs.get("page_offset", None)
if page_idx is None or page_offset is None:
raise ValueError("[KVC2Qwen3Cache] page_idx & page_offset are required in cache_kwargs")
k_out = self.k_caches[layer_idx]
v_out = self.v_caches[layer_idx]
if self.debug_update:
print(
"[KV-UPDATE]",
f"layer={layer_idx}, key={tuple(key_states.shape)}, value={tuple(value_states.shape)}, "
f"page_idx shape={tuple(page_idx.shape)}, page_offset shape={tuple(page_offset.shape)}, "
f"k_out shape={tuple(k_out.shape)}, k_out.dtype={k_out.dtype}",
flush=True,
)
# -------- 1) Fix the dimension order: [B, KvH, Q, D] -> [B, Q, KvH, D] --------
if key_states.dim() == 4 and key_states.shape[1] == self.num_kv_heads:
if self.debug_update:
print(
"[KV-UPDATE] detected layout [B, KvH, Q, D], transpose -> [B, Q, KvH, D]",
flush=True,
)
key_states = key_states.transpose(1, 2).contiguous()
value_states = value_states.transpose(1, 2).contiguous()
if key_states.shape != value_states.shape:
raise ValueError(
f"[KVC2Qwen3Cache] key_states.shape {key_states.shape} "
f"!= value_states.shape {value_states.shape}"
)
if key_states.dim() != 4:
raise ValueError(
f"[KVC2Qwen3Cache] expect key_states dim=4, got {key_states.dim()} "
f"(shape={key_states.shape})"
)
bsz, q_len, kv_heads, head_dim = key_states.shape
# if self.debug_update:
# print(
# "[KV-UPDATE] after layout fix:",
# f"bsz={bsz}, q_len={q_len}, kv_heads={kv_heads}, head_dim={head_dim}",
# flush=True,
# )
if kv_heads != self.num_kv_heads or head_dim != self.head_dim:
raise ValueError(
f"[KVC2Qwen3Cache] KV shape mismatch: "
f"got num_kv_heads={kv_heads}, head_dim={head_dim}, "
f"expected num_kv_heads={self.num_kv_heads}, head_dim={self.head_dim}"
)
# ================== DEBUG: check the page mapping ==================
if os.environ.get("KTRANS_DEBUG_PAGE", "0") == "1":
try:
if not torch.npu.is_current_stream_capturing():
self._debug_check_page_mapping(
page_idx,
page_offset,
bsz=bsz,
q_len=q_len,
layer_idx=layer_idx,
)
except Exception:
pass
if os.environ.get("KTRANS_DEBUG_LAYOUT", "0") == "1":
self._debug_dump_page_layout(
page_idx,
page_offset,
bsz=bsz,
q_len=q_len,
layer_idx=layer_idx,
)
# -------- 2) Flatten page_idx / page_offset to 1-D --------
page_idx = page_idx.reshape(-1)
page_offset = page_offset.reshape(-1)
# -------- 3) Flatten KV and force its dtype to match the cache --------
val_dtype = k_out.dtype
flat_k = key_states.to(val_dtype).reshape(-1, kv_heads, head_dim)
flat_v = value_states.to(val_dtype).reshape(-1, kv_heads, head_dim)
# if self.debug_update:
# print(
# "[KV-UPDATE] flat_k.shape=",
# tuple(flat_k.shape),
# " flat_v.shape=",
# tuple(flat_v.shape),
# " flat_k.dtype=",
# flat_k.dtype,
# flush=True,
# )
# -------- 4) Actually write K / V --------
# k_out / v_out: [num_pages, page_size, num_kv_heads, head_dim]
k_out[page_idx, page_offset] = flat_k
v_out[page_idx, page_offset] = flat_v
# if self.debug_update:
# print(f"[KV-UPDATE] write done for layer {layer_idx}", flush=True)
# ================== DEBUG: read back from the cache after writing and compare ==================
if os.environ.get("KTRANS_DEBUG_KV", "0") == "1":
try:
if not torch.npu.is_current_stream_capturing():
self._debug_verify_k_roundtrip(
flat_k=flat_k,
layer_idx=layer_idx,
page_idx=page_idx,
page_offset=page_offset,
)
except Exception:
pass
return k_out, v_out
# ------------------------- get K/V -------------------------
def get_k_cache(self, layer_idx):
return self.k_caches[layer_idx]
def get_v_cache(self, layer_idx):
return self.v_caches[layer_idx]
# ------------------------- Page table computation -------------------------
def get_page_table(
self,
mini_batch,
bsz_tensors: Optional[torch.Tensor] = None,
is_prefill: bool = True,
):
if is_prefill:
# prefill: merged positions => batched (B, T_chunk)
q_lens = [int(mini_batch.p_q_len[idx]) for idx in range(mini_batch.prefill_batch)]
if len(q_lens) == 0:
return None, None
max_q_len = max(q_lens)
page_local_idx = -1 * torch.ones(
mini_batch.prefill_batch,
max_q_len,
dtype=mini_batch.p_position_ids.dtype,
device=mini_batch.p_position_ids.device,
)
page_offset = -1 * torch.ones_like(page_local_idx)
start_ids = 0
for i in range(mini_batch.prefill_batch):
cur_len = q_lens[i]
pos = mini_batch.p_position_ids[start_ids:start_ids + cur_len] # global pos of this chunk
# local block + offset by page_size
page_offset[i, 0:cur_len] = pos % self.page_size
page_local_idx[i, 0:cur_len] = pos // self.page_size
# local block -> global page id via block_tables
for j in range(cur_len):
blk = page_local_idx[i, j]
page_local_idx[i, j] = mini_batch.p_block_tables[i, blk]
start_ids += cur_len
page_idx = page_local_idx
else:
# decode: decode_batch = batch size of the current step; usually one token per sample
page_local_idx = mini_batch.d_position_ids // self.page_size
page_offset = mini_batch.d_position_ids % self.page_size
for i in range(mini_batch.decode_batch):
blk = page_local_idx[i]
page_local_idx[i] = mini_batch.d_block_tables[i, blk]
page_idx = page_local_idx
return page_idx, page_offset
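# --- Illustrative usage sketch (an assumption, not part of this commit) ---
# A minimal sketch of how this cache is typically driven during decode,
# assuming a `mini_batch` carrying the fields referenced by get_page_table
# above; the surrounding attention computation is omitted.
#
#   cache = KVC2Qwen3Cache(config, max_batch_size=8)
#   cache.load(inference_context)
#   page_idx, page_offset = cache.get_page_table(mini_batch, is_prefill=False)
#   k_out, v_out = cache.update(
#       key_states, value_states, layer_idx,
#       cache_kwargs={"page_idx": page_idx, "page_offset": page_offset},
#   )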


@@ -1,9 +1,30 @@
# coding=utf-8
# Copyright (c) 2025. Huawei Technologies Co., Ltd. All rights reserved.
# Copyright 2025 The ZhipuAI Inc. team and HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
from typing import Optional, Union, Tuple
import torch
import torch_npu
from torch import nn
from transformers import PretrainedConfig
from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.models.modeling_qwen3_moe import Qwen3MoeRMSNorm
from ktransformers.util import utils
from ktransformers.util.custom_loader import GGUFLoader
@@ -36,3 +57,79 @@ class KDeepseekV3RMSNormW8A8(BaseInjectedModule):
self.weight = None
if self.bias is not None:
self.bias = None
class KQwen3MoeRMSNormW8A8(BaseInjectedModule):
def __init__(self,
key: str,
gguf_loader: GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module,
prefill_device: str = "npu",
generate_device: str = "npu",
**kwargs):
super().__init__(key, gguf_loader, config, orig_module,
prefill_device, generate_device, **kwargs)
self.hidden_size = orig_module.hidden_size
self.variance_epsilon = orig_module.variance_epsilon
self.weight = nn.Parameter(orig_module.weight.data.clone())
def forward(self, x: torch.Tensor):
input_dtype = x.dtype  # capture the caller's dtype before the fp16 cast
x = x.to(torch.float16)
gamma = self.weight.to(torch.float16)
out = torch_npu.npu_rms_norm(
x,
gamma,
self.variance_epsilon
)[0]
return out.to(input_dtype)
def load(self):
device = utils.get_current_device()
self.weight = self.gguf_loader.safetensor_loader.load_tensor(self.key + ".weight").to(device)
try:
self.bias = (
self.gguf_loader.safetensor_loader
.load_tensor(self.key + ".bias")
.to(device)
)
except KeyError:
self.bias = None
def unload(self):
self.weight = None
self.bias = None
class KQwen3FinalRMSNormNPU(nn.Module):
def __init__(self, orig_module: nn.Module):
super().__init__()
assert hasattr(orig_module, "weight"), "orig_module must have weight"
self.variance_epsilon = getattr(orig_module, "variance_epsilon", 1e-6)
w = orig_module.weight.detach()
if w.dtype not in (torch.float16, torch.bfloat16):
# keep gamma in half precision; cast fp32 and other dtypes to fp16
w = w.to(torch.float16)
self.weight = nn.Parameter(w)
def forward(self, x: torch.Tensor):
input_dtype = x.dtype
x = x.contiguous()
gamma = self.weight
x_rms = x.to(dtype=gamma.dtype)
out = torch_npu.npu_rms_norm(
x_rms,
gamma,
self.variance_epsilon
)[0]
return out.to(input_dtype)
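# Illustrative usage (an assumption, not part of this commit): wrap a model's
# final RMSNorm so the NPU kernel is used; the attribute path is hypothetical.
#
#   model.model.norm = KQwen3FinalRMSNormNPU(model.model.norm)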


@@ -1,3 +1,19 @@
# coding=utf-8
# Copyright (c) 2025. Huawei Technologies Co., Ltd. All rights reserved.
# Copyright 2025 The ZhipuAI Inc. team and HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import abstractmethod
import torch
@@ -113,6 +129,8 @@ class KLinearTorchW8A8A2(KLinearW8A8):
self.weight_offset = None
def forward(self, x: torch.Tensor, bsz_tensor) -> torch.Tensor:
if x.dtype != self.weight.dtype:
x = x.to(self.weight.dtype)
return torch.matmul(x, self.weight)
def load(self, w: dict | nn.Parameter | tuple | None = None, device: str | None = None):


@@ -1,3 +1,19 @@
# coding=utf-8
# Copyright (c) 2025. Huawei Technologies Co., Ltd. All rights reserved.
# Copyright 2025 The ZhipuAI Inc. team and HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch_npu
@@ -5,6 +21,7 @@ from ktransformers.util.ascend.ascend_utils import allredeuce_warpper
from ktransformers.util.utils import CUR_DEVICE
from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.models.modeling_deepseek_v3 import DeepseekV3MLP
from ktransformers.models.modeling_qwen3_moe import Qwen3MoeMLP
class KDeepseekV3MLPW8A8A2V1(BaseInjectedModule, DeepseekV3MLP):
@allredeuce_warpper
@@ -71,3 +88,44 @@ class KDeepseekV3MLPW8A8A2V2(BaseInjectedModule, DeepseekV3MLP):
)
down_proj = down_proj.reshape(x.shape)
return down_proj
class KQwen3MoeMLPW8A8A2(BaseInjectedModule, Qwen3MoeMLP):
@allredeuce_warpper
def forward(self, x, is_prefill=None, use_cuda_graph=False):
original_dtype = x.dtype
quant_out, dynamic_scale = torch_npu.npu_dynamic_quant(x)
dynamic_scale = dynamic_scale.view(-1)
quant_out = quant_out.view(-1, quant_out.shape[-1])
gate_x = torch_npu.npu_quant_matmul(
quant_out,
self.orig_module.gate_proj.weight,
self.orig_module.gate_proj.weight_scale,
pertoken_scale=dynamic_scale,
bias=None,
output_dtype=original_dtype,
)
up_x = torch_npu.npu_quant_matmul(
quant_out,
self.orig_module.up_proj.weight,
self.orig_module.up_proj.weight_scale,
pertoken_scale=dynamic_scale,
bias=None,
output_dtype=original_dtype,
)
down_x = torch.nn.functional.silu(gate_x) * up_x
down_quant_out, down_dynamic_scale = torch_npu.npu_dynamic_quant(down_x)
down_dynamic_scale = down_dynamic_scale.view(-1)
down_proj = torch_npu.npu_quant_matmul(
down_quant_out,
self.orig_module.down_proj.weight,
self.orig_module.down_proj.weight_scale,
pertoken_scale=down_dynamic_scale,
bias=None,
output_dtype=original_dtype,
)
down_proj = down_proj.reshape(x.shape)
return down_proj
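# Note on the W8A8 flow above: npu_dynamic_quant yields per-token int8
# activations plus a per-token scale, and npu_quant_matmul conceptually
# computes
#   out ≈ (q_x @ q_w) * pertoken_scale[:, None] * weight_scale[None, :]
# in original_dtype; activations are therefore re-quantized once more
# before the down projection.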