add balance-serve, support concurrency

Atream 2025-03-31 22:55:32 +08:00
parent 8d0292aa44
commit 25cee5810e
196 changed files with 22077 additions and 565 deletions

View file

@@ -359,3 +359,56 @@ class DynamicNTKScalingRotaryEmbedding(
self.orig_module.rope_type,
self.orig_module.config,
)
class RotaryEmbeddingV4(BaseInjectedModule):
def __init__(
self,
key: str,
gguf_loader: GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module,
# device: str = "cuda",
generate_device: str = "cuda",
prefill_device: str = "cuda",
**kwargs,
):
BaseInjectedModule.__init__(
self, key, gguf_loader, config, orig_module, generate_device, **kwargs
)
self.generate_device = generate_device
self.prefill_device = prefill_device
@torch.no_grad()
def forward(self, x, position_ids):
# x: [bs, num_attention_heads, seq_len, head_size]
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
# Force float32 since bfloat16 loses precision on long contexts
# See https://github.com/huggingface/transformers/pull/29285
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
def load(self):
self._init(
dim=self.config.qk_rope_head_dim,
max_position_embeddings=self.config.max_position_embeddings,
base=self.config.rope_theta,
device=self.device,
)
def _init(self, dim, max_position_embeddings, base, device, scaling_factor=1.0):
self.scaling_factor = scaling_factor
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
self.inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
# self.register_buffer("inv_freq", inv_freq, persistent=False)
# For BC we register cos and sin cached
self.max_seq_len_cached = max_position_embeddings
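
A minimal standalone sketch of the cos/sin math this module implements, with illustrative dim/base values and a dummy position_ids (names and sizes here are assumptions, not part of the commit):

import torch

dim, base = 64, 10000.0                                    # e.g. qk_rope_head_dim, rope_theta
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))   # [dim/2]

position_ids = torch.arange(8)[None, :]                    # [bs=1, seq_len=8]
freqs = (inv_freq[None, :, None] @ position_ids[:, None, :].float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)                    # [bs, seq_len, dim]
cos, sin = emb.cos(), emb.sin()                            # forward() returns these, cast to x.dtype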

View file

@@ -32,7 +32,8 @@ import os
from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
if flashinfer_enabled:
from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton
from flashinfer.mla import BatchMLAPagedAttentionWrapper
from ktransformers.models.custom_cache import KDeepSeekV3Cache
logger = logging.getLogger("attention")
# Copied from transformers.models.llama.modeling_llama.rotate_half
@@ -759,3 +760,92 @@ class KLlamaAttention(BaseInjectedModule):
attn_weights = None
return attn_output, attn_weights, past_key_value
class flashinfer_attn(BaseInjectedModule, DeepseekV2Attention):
def __init__(self,
key: str,
gguf_loader : GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module,
prefill_device: str = "cuda",
generate_device: str = "cuda",
chunck_size: int = 1000,
**kwargs):
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
self.orig_module.__init__(orig_module.config,
orig_module.layer_idx)
self.chunck_size = chunck_size # TODO: generate chunck_size automatically.
def get_absorbed(self) -> Tuple[torch.Tensor, torch.Tensor]:
if not (hasattr(self, 'q_absorb') and hasattr(self, 'out_absorb')):
kv_b_proj = self.kv_b_proj.weight.view(self.num_heads, -1, self.kv_lora_rank)
q_absorb = kv_b_proj[:, :self.qk_nope_head_dim, :].reshape(-1, self.kv_lora_rank)
out_absorb = kv_b_proj[:, self.qk_nope_head_dim:, :].reshape(-1, self.kv_lora_rank)
self.q_absorb = nn.Linear(self.kv_lora_rank, self.num_heads * self.qk_nope_head_dim,
bias=False, dtype=q_absorb.dtype, device=q_absorb.device)
self.q_absorb.weight.data = q_absorb
self.out_absorb = nn.Linear(self.kv_lora_rank, self.num_heads * self.v_head_dim,
bias=False, dtype=out_absorb.dtype, device=out_absorb.device)
self.out_absorb.weight.data = out_absorb
#del self.orig_module.kv_b_proj
q_absorb = self.q_absorb.weight.view(self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank)
out_absorb = self.out_absorb.weight.view(self.num_heads, self.v_head_dim, self.kv_lora_rank)
return q_absorb, out_absorb
def forward(self,
hidden_states: torch.Tensor,
kv_cache: KDeepSeekV3Cache,
position_ids: torch.Tensor,
wrapper: BatchMLAPagedAttentionWrapper,
num_tokens_tensors: torch.Tensor,
page_idx: torch.Tensor,
page_offset: torch.Tensor,
):
q_len, _ = hidden_states.size()
if self.q_lora_rank is None:
q = self.q_proj(hidden_states, num_tokens_tensors)
else:
q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states, num_tokens_tensors), num_tokens_tensors), num_tokens_tensors)
q = q.view(q_len, self.num_heads, self.q_head_dim)
q_nope, q_pe = torch.split(
q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
)
compressed_kv = self.kv_a_proj_with_mqa(hidden_states, num_tokens_tensors)
compressed_kv, k_pe = torch.split(
compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
)
compressed_kv = compressed_kv.contiguous()
compressed_kv = self.kv_a_layernorm(compressed_kv, num_tokens_tensors)
k_pe = k_pe.view(q_len, 1, self.qk_rope_head_dim)
compressed_kv = compressed_kv.view(q_len, 1, self.kv_lora_rank)
cos, sin = self.rotary_emb(q_pe, position_ids.unsqueeze(0))
q_pe, k_pe = apply_rotary_pos_emb(q_pe.unsqueeze(0), k_pe.unsqueeze(0), cos, sin, unsqueeze_dim=2)
q_pe = q_pe.squeeze(0)
if kv_cache is not None:
# page_idx, page_offset = kv_cache.get_page_table(position_ids, q_indptr, kv_indptr, kv_indices)
cache_kwargs = {"sin": sin, "cos": cos, "page_idx": page_idx, "page_offset": page_offset} # Specific to RoPE models
compressed_kv_with_k_pe = kv_cache.update(compressed_kv.unsqueeze(0), k_pe, self.layer_idx, page_idx, page_offset, cache_kwargs)
compressed_kv = compressed_kv_with_k_pe[:, :, :, :self.kv_lora_rank].view(-1, kv_cache.page_size, self.kv_lora_rank)
k_pe = compressed_kv_with_k_pe[:, :, :, self.kv_lora_rank:].view(-1, kv_cache.page_size, self.qk_rope_head_dim)
q_absorb, out_absorb = self.get_absorbed()
q_nope = q_nope.transpose(0, 1) # q_len is 1, no GPU overhead, same below
q_nope = torch.matmul(q_nope, q_absorb) # batched MM
q_nope = q_nope.transpose(0, 1)
# q_nope.squeeze_(1)
# q_pe.squeeze_(1)
attn_output = wrapper.run(q_nope, q_pe, compressed_kv, k_pe).view(q_len, self.num_heads, self.kv_lora_rank)
attn_output = attn_output.transpose(0, 1)
attn_output = torch.matmul(attn_output, out_absorb.mT) # [self.num_heads, q_len, self.v_head_dim]
attn_output = attn_output.transpose(0, 1)
attn_output = attn_output.reshape(q_len, self.num_heads * self.v_head_dim)
attn_output = self.o_proj(attn_output, num_tokens_tensors)
return attn_output
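
get_absorbed works because attention scores commute with the KV up-projection: the decompressed key is k_nope = W_uk · c_kv, so q_nope · k_nope = (q_nope · W_uk) · c_kv, and attention can run entirely in the kv_lora_rank space (W_uv is folded into the output path the same way). A toy sketch of that equivalence, with illustrative sizes and names:

import torch

heads, d_nope, d_c, T = 2, 4, 8, 16          # toy num_heads, qk_nope_head_dim, kv_lora_rank, cache length
q_nope = torch.randn(heads, 1, d_nope)       # one query token per head
c_kv   = torch.randn(T, d_c)                 # compressed KV cache, shared by all heads
W_uk   = torch.randn(heads, d_nope, d_c)     # per-head slice of kv_b_proj

# Naive: decompress the keys per head, then score against q_nope.
k_nope = c_kv @ W_uk.transpose(1, 2)                        # [heads, T, d_nope]
scores_naive = torch.bmm(q_nope, k_nope.transpose(1, 2))    # [heads, 1, T]

# Absorbed: fold W_uk into the query and score in the compressed space.
q_absorbed = torch.bmm(q_nope, W_uk)                        # [heads, 1, d_c]
scores_absorbed = q_absorbed @ c_kv.T                       # [heads, 1, T]

torch.testing.assert_close(scores_naive, scores_absorbed)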

View file

@@ -37,6 +37,10 @@ import time
from ktransformers.operators.cpuinfer import CPUInfer
def deduplicate_and_sort(lst):
return sorted(set(lst))
#cuda_graphs = [Config().chunk_size]
cuda_graphs = deduplicate_and_sort([1, 2, 3, Config().max_batch_size, 64, Config().chunk_size])
# class Base(BaseInjectedModule, ABC):
class KExpertsBase(ABC):
def __init__(self, key: str, gguf_loader: GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, device: str = "cuda", **kwargs):
@@ -112,6 +116,7 @@ class KExpertsBase(ABC):
tensors[k] = self.gguf_loader.load_gguf_tensor(key + k, device=device)
return tensors
class KExpertsCPU(KExpertsBase):
input_tensor_cpu:Tensor = None
expert_ids_cpu:Tensor = None
@@ -119,8 +124,8 @@ class KExpertsCPU(KExpertsBase):
output_cpu:Tensor = None
output_gpu_map:dict = {} # Manage output tensor buffer on different gpu
#stream_map:dict = {} # Manage cuda stream on different gpu
#gguf_loader:GGUFLoader = None
CPU_INFER = None
# @TODO add yaml
CPU_INFER = CPUInfer(Config().cpu_infer)
def __init__(
self,
key: str,
@@ -133,11 +138,6 @@ class KExpertsCPU(KExpertsBase):
**kwargs
):
super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
if KExpertsCPU.CPU_INFER is None:
KExpertsCPU.CPU_INFER = CPUInfer(Config().cpu_infer)
#if KExpertsCPU.gguf_loader is None:
# KExpertsCPU.gguf_loader = GGUFLoader("/mnt/data/model/DeepseekV3-q4km-gguf")
self.gguf_loader = gguf_loader
assert device.lower() == "cpu", "KExpertsCPU can only be loaded on CPU"
self.n_routed_experts = n_routed_experts
self.out_device = out_device
@@ -161,7 +161,7 @@ class KExpertsCPU(KExpertsBase):
down_ptr = ctypes.addressof(
ctypes.cast(self.down.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents
)
#print(self.gate_type, self.up_type, self.down_type)
# print(self.gate_qtype, self.up_qtype, self.down_qtype)
n_routed_experts = self.n_routed_experts
# n_routed_experts = len(self.orig_module)
moe_config = MOEConfig(
@@ -188,43 +188,83 @@ class KExpertsCPU(KExpertsBase):
self.cpu_infer.submit(self.moe.warm_up())
self.cpu_infer.sync()
if self.out_device not in KExpertsCPU.output_gpu_map:
KExpertsCPU.output_gpu_map[self.out_device] = torch.zeros((self.config.hidden_size), device=self.out_device)
if isinstance(cuda_graphs, list):
KExpertsCPU.output_gpu_map[self.out_device] = [torch.zeros((cuda_graphs[i], self.config.hidden_size), device=self.out_device) for i in range(len(cuda_graphs))]
else:
KExpertsCPU.output_gpu_map[self.out_device] = torch.zeros((cuda_graphs, self.config.hidden_size), device=self.out_device)
if KExpertsCPU.input_tensor_cpu == None:
KExpertsCPU.input_tensor_cpu = torch.zeros((self.config.hidden_size), device="cpu", pin_memory=True)
KExpertsCPU.expert_ids_cpu = torch.zeros((num_experts_per_tok), device="cpu", dtype=torch.long, pin_memory=True)
KExpertsCPU.weights_cpu = torch.zeros((num_experts_per_tok), device="cpu", dtype=torch.float32, pin_memory=True)
KExpertsCPU.output_cpu = torch.zeros((self.config.hidden_size), device="cpu", pin_memory=True, dtype=torch.bfloat16)
if isinstance(cuda_graphs, list):
KExpertsCPU.input_tensor_cpu = [torch.zeros((cuda_graphs[i], self.config.hidden_size), device="cpu", pin_memory=True) for i in range(len(cuda_graphs))]
KExpertsCPU.expert_ids_cpu = [torch.zeros((cuda_graphs[i], num_experts_per_tok), device="cpu", dtype=torch.long, pin_memory=True) for i in range(len(cuda_graphs))]
KExpertsCPU.weights_cpu = [torch.zeros((cuda_graphs[i], num_experts_per_tok), device="cpu", dtype=torch.float32, pin_memory=True) for i in range(len(cuda_graphs))]
KExpertsCPU.output_cpu = [torch.zeros((cuda_graphs[i], self.config.hidden_size), device="cpu", pin_memory=True, dtype=torch.bfloat16) for i in range(len(cuda_graphs))]
KExpertsCPU.bsz_tensor_cpu = [torch.zeros((1), device="cpu", dtype=torch.int32, pin_memory=True) for i in range(len(cuda_graphs))]
else:
KExpertsCPU.input_tensor_cpu = torch.zeros((cuda_graphs, self.config.hidden_size), device="cpu", pin_memory=True)
KExpertsCPU.expert_ids_cpu = torch.zeros((cuda_graphs, num_experts_per_tok), device="cpu", dtype=torch.long, pin_memory=True)
KExpertsCPU.weights_cpu = torch.zeros((cuda_graphs, num_experts_per_tok), device="cpu", dtype=torch.float32, pin_memory=True)
KExpertsCPU.output_cpu = torch.zeros((cuda_graphs, self.config.hidden_size), device="cpu", pin_memory=True, dtype=torch.bfloat16)
KExpertsCPU.bsz_tensor_cpu = torch.zeros((1), device="cpu", dtype=torch.int32, pin_memory=True)
def submit_for_one_decode(self, input_tensor, expert_ids, weights):
KExpertsCPU.input_tensor_cpu.copy_(input_tensor, non_blocking=True)
KExpertsCPU.expert_ids_cpu.copy_(expert_ids, non_blocking=True)
KExpertsCPU.weights_cpu.copy_(weights, non_blocking=True)
self.cpu_infer.submit_with_cuda_stream(torch.cuda.current_stream(self.out_device).cuda_stream, self.moe.forward(1, expert_ids.size(0), KExpertsCPU.expert_ids_cpu.data_ptr(), KExpertsCPU.weights_cpu.data_ptr(), KExpertsCPU.input_tensor_cpu.data_ptr(), KExpertsCPU.output_cpu.data_ptr()))
def sync_for_one_decode(self):
self.cpu_infer.sync_with_cuda_stream(torch.cuda.current_stream(self.out_device).cuda_stream)
KExpertsCPU.output_gpu_map[self.out_device].copy_(KExpertsCPU.output_cpu, non_blocking=True)
return KExpertsCPU.output_gpu_map[self.out_device]
def forward(self, input_tensor, expert_ids, weights):
# generate, capture and run cuda graph
# print(expert_ids)
if input_tensor.size(0)==1 and torch.cuda.is_current_stream_capturing():
# TODO: this branch is unreachable, but the shapes of input_tensor ([1, hidden_size]) and input_tensor_cpu ([hidden_size]) are not compatible
#print("capturing experts")
def submit_for_one_decode(self, input_tensor, expert_ids, weights, bsz_tensor=None, cuda_graph_idx=0):
if bsz_tensor is None:
bsz_tensor = torch.ones(1, device=input_tensor.device, dtype=torch.int32)
if cuda_graph_idx != -1:
KExpertsCPU.input_tensor_cpu[cuda_graph_idx].copy_(input_tensor, non_blocking=True)
KExpertsCPU.expert_ids_cpu[cuda_graph_idx].copy_(expert_ids, non_blocking=True)
KExpertsCPU.weights_cpu[cuda_graph_idx].copy_(weights, non_blocking=True)
KExpertsCPU.bsz_tensor_cpu[cuda_graph_idx].copy_(bsz_tensor, non_blocking=True)
self.cpu_infer.submit_with_cuda_stream(torch.cuda.current_stream(self.out_device).cuda_stream, self.moe.forward(1, expert_ids.size(-1), KExpertsCPU.expert_ids_cpu[cuda_graph_idx].data_ptr(), KExpertsCPU.weights_cpu[cuda_graph_idx].data_ptr(), KExpertsCPU.input_tensor_cpu[cuda_graph_idx].data_ptr(), KExpertsCPU.output_cpu[cuda_graph_idx].data_ptr(), KExpertsCPU.bsz_tensor_cpu[cuda_graph_idx].data_ptr()))
else:
KExpertsCPU.input_tensor_cpu.copy_(input_tensor, non_blocking=True)
KExpertsCPU.expert_ids_cpu.copy_(expert_ids, non_blocking=True)
KExpertsCPU.weights_cpu.copy_(weights, non_blocking=True)
self.cpu_infer.submit_with_cuda_stream(torch.cuda.current_stream().cuda_stream, self.moe.forward(1, expert_ids.size(1), KExpertsCPU.expert_ids_cpu.data_ptr(), KExpertsCPU.weights_cpu.data_ptr(), KExpertsCPU.input_tensor_cpu.data_ptr(), KExpertsCPU.output_cpu.data_ptr()))
self.cpu_infer.sync_with_cuda_stream(torch.cuda.current_stream().cuda_stream)
KExpertsCPU.bsz_tensor_cpu.copy_(bsz_tensor, non_blocking=True)
self.cpu_infer.submit_with_cuda_stream(torch.cuda.current_stream(self.out_device).cuda_stream, self.moe.forward(1, expert_ids.size(-1), KExpertsCPU.expert_ids_cpu.data_ptr(), KExpertsCPU.weights_cpu.data_ptr(), KExpertsCPU.input_tensor_cpu.data_ptr(), KExpertsCPU.output_cpu.data_ptr(), KExpertsCPU.bsz_tensor_cpu.data_ptr()))
def sync_for_one_decode(self, cuda_graph_idx=0):
if cuda_graph_idx != -1:
self.cpu_infer.sync_with_cuda_stream(torch.cuda.current_stream(self.out_device).cuda_stream)
KExpertsCPU.output_gpu_map[self.out_device][cuda_graph_idx].copy_(KExpertsCPU.output_cpu[cuda_graph_idx], non_blocking=True)
return KExpertsCPU.output_gpu_map[self.out_device][cuda_graph_idx]
else:
self.cpu_infer.sync_with_cuda_stream(torch.cuda.current_stream(self.out_device).cuda_stream)
KExpertsCPU.output_gpu_map[self.out_device].copy_(KExpertsCPU.output_cpu, non_blocking=True)
return KExpertsCPU.output_gpu_map[self.out_device]
def forward(self, input_tensor, expert_ids, weights, bsz_tensor=None, cuda_graph_idx=0):
# generate, capture and run cuda graph
# print(expert_ids)
if bsz_tensor is None:
bsz_tensor = torch.tensor([input_tensor.size(0)], device=input_tensor.device, dtype=torch.int32)
if torch.cuda.is_current_stream_capturing():
if cuda_graph_idx != -1:
KExpertsCPU.input_tensor_cpu[cuda_graph_idx].copy_(input_tensor, non_blocking=True)
KExpertsCPU.expert_ids_cpu[cuda_graph_idx].copy_(expert_ids, non_blocking=True)
KExpertsCPU.weights_cpu[cuda_graph_idx].copy_(weights, non_blocking=True)
KExpertsCPU.bsz_tensor_cpu[cuda_graph_idx].copy_(bsz_tensor, non_blocking=True)
self.cpu_infer.submit_with_cuda_stream(torch.cuda.current_stream().cuda_stream, self.moe.forward(expert_ids.size(0), expert_ids.size(-1), KExpertsCPU.expert_ids_cpu[cuda_graph_idx].data_ptr(), KExpertsCPU.weights_cpu[cuda_graph_idx].data_ptr(), KExpertsCPU.input_tensor_cpu[cuda_graph_idx].data_ptr(), KExpertsCPU.output_cpu[cuda_graph_idx].data_ptr(), KExpertsCPU.bsz_tensor_cpu[cuda_graph_idx].data_ptr()))
self.cpu_infer.sync_with_cuda_stream(torch.cuda.current_stream().cuda_stream)
KExpertsCPU.output_gpu_map[self.out_device][cuda_graph_idx].copy_(KExpertsCPU.output_cpu[cuda_graph_idx], non_blocking=True)
return KExpertsCPU.output_gpu_map[self.out_device][cuda_graph_idx]
else:
KExpertsCPU.input_tensor_cpu.copy_(input_tensor, non_blocking=True)
KExpertsCPU.expert_ids_cpu.copy_(expert_ids, non_blocking=True)
KExpertsCPU.weights_cpu.copy_(weights, non_blocking=True)
KExpertsCPU.bsz_tensor_cpu.copy_(bsz_tensor, non_blocking=True)
self.cpu_infer.submit_with_cuda_stream(torch.cuda.current_stream().cuda_stream, self.moe.forward(expert_ids.size(0), expert_ids.size(-1), KExpertsCPU.expert_ids_cpu.data_ptr(), KExpertsCPU.weights_cpu.data_ptr(), KExpertsCPU.input_tensor_cpu.data_ptr(), KExpertsCPU.output_cpu.data_ptr(), KExpertsCPU.bsz_tensor_cpu.data_ptr()))
self.cpu_infer.sync_with_cuda_stream(torch.cuda.current_stream().cuda_stream)
KExpertsCPU.output_gpu_map[self.out_device].copy_(KExpertsCPU.output_cpu, non_blocking=True)
return KExpertsCPU.output_gpu_map[self.out_device]
else:
input_tensor = input_tensor.contiguous().cpu()
expert_ids = expert_ids.contiguous().cpu()
weights = weights.contiguous().to(torch.float32).cpu()
bsz_tensor = bsz_tensor.contiguous().cpu()
output = torch.empty_like(input_tensor).contiguous()
self.cpu_infer.submit(self.moe.forward(expert_ids.size(0), expert_ids.size(1), expert_ids.data_ptr(), weights.data_ptr(), input_tensor.data_ptr(), output.data_ptr()))
self.cpu_infer.submit(self.moe.forward(expert_ids.size(0), expert_ids.size(1), expert_ids.data_ptr(), weights.data_ptr(), input_tensor.data_ptr(), output.data_ptr(), bsz_tensor.data_ptr()))
self.cpu_infer.sync()
return output.to(device=object.__getattribute__(self, "out_device"))
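
The capture branch above works because CUDA graphs replay the exact pointers recorded at capture time: all host/device traffic is staged through class-level pinned CPU buffers whose addresses never change, one buffer set per captured batch size (cuda_graph_idx). A stripped-down sketch of the pattern (buffer names and sizes are illustrative):

import torch

hidden = 16
# Fixed-address staging buffers; pinned memory makes non_blocking copies truly async.
inp_cpu = torch.zeros(4, hidden, pin_memory=True)
out_cpu = torch.zeros(4, hidden, pin_memory=True)
out_gpu = torch.zeros(4, hidden, device="cuda")

def step(x_gpu):
    inp_cpu.copy_(x_gpu, non_blocking=True)    # D2H into the pinned buffer
    # ... CPU-side expert kernel reads inp_cpu and fills out_cpu here ...
    out_gpu.copy_(out_cpu, non_blocking=True)  # H2D back into a fixed GPU buffer
    return out_gpu                             # same tensor object on every call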
@@ -859,6 +899,8 @@ class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE):
y += y_
return y
@torch.no_grad()
def moe_kexperts(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
outs = self.experts(x, topk_ids, topk_weight)
@@ -1013,4 +1055,178 @@ class KMistralSparseMoEBlock(BaseInjectedModule, MixtralSparseMoeBlock):
# the `top_x` tensor here.
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states_cpu.dtype))
return final_hidden_states
class KDeepseekV3MoEV2(BaseInjectedModule, DeepseekV3MoE):
def forward(self, hidden_states, bsz_tensor, cuda_graph_idx=0):
identity = hidden_states
orig_shape = hidden_states.shape
sequence_length = orig_shape[1]
topk_idx, topk_weight = self.gate(hidden_states)
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
# only for generate phase
if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing(): # TODO: this branch causes a jit bug
self.experts.generate_experts.submit_for_one_decode(hidden_states, topk_idx, topk_weight, bsz_tensor, cuda_graph_idx)
if self.config.n_shared_experts is not None:
y_ = self.shared_experts(identity, bsz_tensor).squeeze(0)
y = self.experts.generate_experts.sync_for_one_decode(cuda_graph_idx).unsqueeze(0)
y += y_
y.resize_(*orig_shape)
return y
if self.config.n_shared_experts is not None:
y_ = self.shared_experts(identity, bsz_tensor).squeeze(0)
if isinstance(self.experts, KExpertsBase):
y = self.moe_on_cpuinfer(hidden_states, topk_idx, topk_weight, bsz_tensor, cuda_graph_idx).view(*orig_shape).to(device=hidden_states.device)
elif hidden_states.size(0) > 10:
# TODO: may have bugs here
y = (
self.moe_infer(hidden_states, topk_idx, topk_weight)
.view(*orig_shape)
.to(device=hidden_states.device)
)
else:
# TODO: may have bugs here
y = (
self.moe_infer_simple(hidden_states, topk_idx, topk_weight)
.view(*orig_shape)
.to(device=hidden_states.device)
)
if self.config.n_shared_experts is not None:
y += y_
return y
@torch.no_grad()
def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor, bsz_tensor, cuda_graph_idx=0) -> torch.Tensor:
outs = torch.empty_like(x)
outs = self.experts(x, topk_ids, topk_weight, bsz_tensor, cuda_graph_idx)
return outs
@torch.no_grad()
# TODO: may have bugs here
def moe_infer_simple(
self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor
) -> torch.Tensor:
"""
x: [num_tokens, hidden_size]
topk_ids, topk_weight: [num_tokens, num_selected_experts]
"""
outs = torch.zeros_like(x)
for token_idx in range(topk_ids.size(0)):
for expert_idx in range(topk_ids.size(1)):
expert = self.experts[topk_ids[token_idx, expert_idx]]
outs[token_idx] += (
expert.forward(x[token_idx]) * topk_weight[token_idx, expert_idx]
)
return outs
@torch.no_grad()
# TODO: may have bugs here
def moe_infer(self, x, topk_ids, topk_weight):
cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
cnts.scatter_(1, topk_ids, 1)
tokens_per_expert = cnts.sum(dim=0)
idxs = topk_ids.view(-1).argsort()
sorted_tokens = x[idxs // topk_ids.shape[1]]
tokens_per_expert = tokens_per_expert.cpu().numpy()
outputs = []
start_idx = 0
for i, num_tokens in enumerate(tokens_per_expert):
end_idx = start_idx + num_tokens
if num_tokens == 0:
continue
expert = self.experts[i + self.ep_rank * self.experts_per_rank]
tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
expert_out = expert.forward(tokens_for_this_expert)
outputs.append(expert_out)
start_idx = end_idx
outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)
new_x = torch.empty_like(outs)
new_x[idxs] = outs
final_out = (
new_x.view(*topk_ids.shape, -1)
.type(topk_weight.dtype)
.mul_(topk_weight.unsqueeze(dim=-1))
.sum(dim=1)
.type(new_x.dtype)
)
return final_out
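
moe_infer batches tokens per expert before any expert runs: scatter_ builds a token-by-expert indicator matrix, and argsort over the flattened topk_ids yields a permutation that lays the tokens out expert by expert. A tiny worked example with illustrative values:

import torch

topk_ids = torch.tensor([[0, 2], [1, 2], [0, 1]])  # 3 tokens, top-2 routing, 3 experts
cnts = topk_ids.new_zeros((3, 3))
cnts.scatter_(1, topk_ids, 1)
print(cnts.sum(dim=0))               # tensor([2, 2, 2]): tokens routed to each expert

idxs = topk_ids.view(-1).argsort()   # slots ordered by expert id
print(idxs // topk_ids.shape[1])     # tensor([0, 2, 1, 2, 0, 1]): source token per slot,
                                     # grouped as expert 0 | expert 1 | expert 2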
class KTransformersExpertsV2(BaseInjectedModule, KExpertsBase):
def __init__(self,
key: str,
gguf_loader: GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module,
# device: str = "cuda",
prefill_device:str = "cuda",
prefill_op: str | None = "KExpertsTorch",
generate_device: str = "cpu",
generate_op: str | None = "KExpertsCPU",
**kwargs):
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
KExpertsBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
if generate_op is not None:
self.generate_experts = EXPERTS_MAP[generate_op](key, gguf_loader, config, len(orig_module), device=generate_device, **kwargs)
else:
self.generate_experts = None
if prefill_op is not None:
self.prefill_experts = EXPERTS_MAP[prefill_op](key, gguf_loader, config, len(orig_module), device=prefill_device, **kwargs)
else:
self.prefill_experts = None
self.gpu_mlp_type = prefill_op
self.cpu_mlp_type = generate_op
self.mode = InferenceState.UNLOAD
def load(self, w: dict = None, mode: InferenceState = None, warmup: bool = True):
# TODO support w as input
if not mode: mode = InferenceState.GENERATE
if mode == InferenceState.GENERATE:
self.prefill_experts.unload()
self.generate_experts.load(w, warmup=warmup)
self.device = self.generate_experts.device
self.mode = mode
elif mode == InferenceState.PREFILL:
self.generate_experts.unload()
self.prefill_experts.load(w, warmup=warmup)
self.device = self.prefill_experts.device
self.mode = mode
elif mode == InferenceState.UNLOAD:
self.unload()
self.mode = mode
self.device = self.generate_experts.device
else:
raise ValueError("mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD")
def unload(self):
if self.generate_experts is not None:
self.generate_experts.unload()
if self.prefill_experts is not None:
self.prefill_experts.unload()
self.device = self.generate_experts.device
def forward(self, input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx=0):
if self.mode == InferenceState.GENERATE:
assert self.generate_experts is not None, "generate_experts is None"
return self.generate_experts.forward(input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx)
elif self.mode == InferenceState.PREFILL:
assert self.prefill_experts is not None, "prefill_experts is None"
return self.prefill_experts.forward(input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx)
else:
raise ValueError("load or set_inference_mode before forward")
def set_inference_mode(self, mode: InferenceState):
if mode == InferenceState.GENERATE:
self.load(mode=InferenceState.GENERATE, warmup=False)
elif mode == InferenceState.PREFILL:
self.load(mode=InferenceState.PREFILL, warmup=False)
elif mode == InferenceState.UNLOAD:
self.unload()
else:
raise ValueError("mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD")
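
Putting the V2 pieces together, a hedged sketch of the expected lifecycle (construction by the injection framework is elided; shapes are illustrative):

# experts: a KTransformersExpertsV2 instance produced by the injection framework
experts.load(mode=InferenceState.PREFILL)                   # GPU prefill experts active
y = experts.forward(x, topk_ids, topk_weight, bsz_tensor)
experts.set_inference_mode(InferenceState.GENERATE)         # swap to KExpertsCPU
y = experts.forward(x, topk_ids, topk_weight, bsz_tensor, cuda_graph_idx=0)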

View file

@@ -86,6 +86,7 @@ class MLAWrapper():
self.qo_indptr_buf = torch.empty(max_batch_size+1, dtype=torch.int32, device=device)
self.kv_indptr_buf = torch.empty(max_batch_size+1, dtype=torch.int32, device=device)
self.kv_indices_buf = torch.empty(max_pages, dtype=torch.int32, device=device)
self.batch_size_tensor_buf = torch.tensor([self.max_batch_size], dtype=torch.int32, device=device)
self.kv_len_arr_buf = torch.empty(max_batch_size, dtype=torch.int32, device=device)
else:
self.qo_indptr_buf = None
@@ -94,19 +95,22 @@ class MLAWrapper():
self.kv_len_arr_buf = None
self.wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(
self.float_workspace_buffer,
use_cuda_graph=use_cuda_graph,
qo_indptr=self.qo_indptr_buf,
kv_indptr=self.kv_indptr_buf,
kv_indices=self.kv_indices_buf,
kv_len_arr=self.kv_len_arr_buf,
bsz_tensor=self.batch_size_tensor_buf
)
self.need_plan = True
def plan(self,
qo_indptr,
kv_indptr,
kv_indices,
kv_len_arr,
bsz_tensor,
num_heads,
head_dim_ckv,
head_dim_kpe,
@@ -138,6 +142,7 @@ class MLAWrapper():
sm_scale,
q_data_type,
kv_data_type,
bsz_tensor
)
def run(self, q_nope, q_pe, ckv, k_pe, return_lse = False):
@@ -240,16 +245,17 @@ if __name__ == "__main__":
#checksame()
#exit(0)
max_batch_size = 1
max_pages = 64
max_batch_size = 2
max_batch_tokens = 256
max_pages = 128
page_size = 64
num_heads = 128
# warm-up
kv_len = 4023
q_len = 1
q_nope_buf = torch.randn((q_len, num_heads, 512), dtype=torch.bfloat16, device="cuda")
q_pe_buf = torch.randn((q_len, num_heads, 64), dtype=torch.bfloat16, device="cuda")
q_nope_buf = torch.randn((max_batch_tokens, num_heads, 512), dtype=torch.bfloat16, device="cuda")
q_pe_buf = torch.randn((max_batch_tokens, num_heads, 64), dtype=torch.bfloat16, device="cuda")
kv_buf = torch.randn((max_pages, page_size, 576), dtype=torch.bfloat16, device="cuda")
ckv, k_pe = torch.split(kv_buf, [512, 64], dim=-1)
@@ -260,13 +266,19 @@ if __name__ == "__main__":
max_pages,
)
used_pages = (kv_len + page_size - 1) // page_size
kv_len_arr = torch.tensor([kv_len], dtype=torch.int32, device="cuda")
qo_indptr = torch.tensor([0, q_len], dtype=torch.int32, device="cuda")
kv_indptr = torch.tensor([0, used_pages], dtype=torch.int32, device="cuda")
kv_indices = torch.empty(max_pages, dtype=torch.int32, device="cuda")
kv_indices[:used_pages] = torch.arange(0, used_pages, dtype=torch.int32, device="cuda")
bsz_tensor = torch.tensor([1], dtype=torch.int32, device="cuda")
wrapper.plan(
qo_indptr,
kv_indptr,
kv_indices,
kv_len_arr,
bsz_tensor,
128,
512,
64,
@@ -276,14 +288,98 @@ if __name__ == "__main__":
torch.bfloat16,
)
attn_output = wrapper.run(q_nope_buf, q_pe_buf, ckv, k_pe)
attn_output = wrapper.run(q_nope_buf[:q_len], q_pe_buf[:q_len], ckv, k_pe)
print(attn_output.shape)
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
attn_output = wrapper.run(q_nope_buf, q_pe_buf, ckv, k_pe)
graph.replay()
q = torch.cat([q_nope_buf, q_pe_buf], dim=-1)
k = (
torch.cat([ckv, k_pe], dim=-1)
.view(-1, 1, 512 + 64)
.repeat_interleave(num_heads, dim=1)
)
v = ckv.view(-1, 1, 512).repeat_interleave(num_heads, dim=1)
attn_ref, lse_ref = attention_ref_torch(
1,
q[:q_len],
k[:kv_len],
v[:kv_len],
True,
192 ** (-0.5)
)
torch.testing.assert_close(attn_output[:q_len], attn_ref, rtol=5e-3, atol=5e-3)
# warm-up finished
kv_len = 512
q_len = 128
pages = max_pages
used_pages = (kv_len + page_size - 1) // page_size
q_nope = torch.randn((q_len*2, num_heads, 512), dtype=torch.bfloat16, device="cuda")
q_nope[q_len:] = q_nope[:q_len]
q_pe = torch.randn((q_len*2, num_heads, 64), dtype=torch.bfloat16, device="cuda")
q_pe[q_len:] = q_pe[:q_len]
kv_cache = torch.randn((max_pages, page_size, 576), dtype=torch.bfloat16, device="cuda")
kv_cache[used_pages:2*used_pages] = kv_cache[:used_pages]
ckv, k_pe = torch.split(kv_cache, [512, 64], dim=-1)
kv_len_arr = torch.tensor([kv_len, kv_len], dtype=torch.int32, device="cuda")
qo_indptr = torch.tensor([0, q_len, q_len*2], dtype=torch.int32, device="cuda")
kv_indptr = torch.tensor([0, used_pages, used_pages*2], dtype=torch.int32, device="cuda")
kv_indices = torch.empty(max_pages, dtype=torch.int32, device="cuda")
kv_indices[:2*used_pages] = torch.arange(0, 2*used_pages, dtype=torch.int32, device="cuda")
bsz_tensor = torch.tensor([2], dtype=torch.int32, device="cuda")
wrapper.plan(
qo_indptr,
kv_indptr,
kv_indices,
kv_len_arr,
bsz_tensor,
128,
512,
64,
page_size,
192 ** (-0.5),
torch.bfloat16,
torch.bfloat16,
)
q_nope_buf.copy_(q_nope)
q_pe_buf.copy_(q_pe)
kv_buf[:pages].copy_(kv_cache)
torch.cuda.synchronize()
graph.replay()
torch.cuda.synchronize()
# ref_torch
q = torch.cat([q_nope, q_pe], dim=-1)
k = (
torch.cat([ckv, k_pe], dim=-1)
.view(-1, 1, 512 + 64)
.repeat_interleave(num_heads, dim=1)
)
v = ckv.view(-1, 1, 512).repeat_interleave(num_heads, dim=1)
attn_ref, lse_ref = attention_ref_torch(
max_batch_size,
q,
k[:2*kv_len],
v[:2*kv_len],
True,
192 ** (-0.5)
)
torch.testing.assert_close(attn_ref[:q_len], attn_ref[q_len:q_len*2], rtol=1e-9, atol=1e-9)
torch.testing.assert_close(attn_output[:q_len], attn_output[q_len:q_len*2], rtol=1e-9, atol=1e-9)
torch.testing.assert_close(attn_output[:q_len], attn_ref[:q_len], rtol=5e-3, atol=5e-3)
torch.testing.assert_close(attn_output[q_len:q_len*2], attn_ref[q_len:q_len*2], rtol=5e-3, atol=5e-3)
#torch.testing.assert_close(attn_output[:q_len], attn_output[q_len:q_len*2], rtol=1e-9, atol=1e-9)
#torch.testing.assert_close(attn_output, attn_ref, rtol=5e-3, atol=5e-3)
exit(0)
for forward_id in range(0, 1):
print("forward_id", forward_id)
for layer_id in range(1):
@@ -376,5 +472,4 @@ if __name__ == "__main__":
#file_name = f"./flashinfer_output/layer_{layer_id}_forward_{forward_id}_attn_output.pt"
#ktrans_output = torch.load(file_name)
#torch.testing.assert_close(attn_output, ktrans_output.squeeze(1), rtol=1e-3, atol=1e-3)
print("test past")
print("test past")

View file

@@ -249,4 +249,4 @@ class KMoEGateDeepSeekV3(BaseInjectedModule, KMoEGateBase):
if self.weight is not None:
self.weight = None
if self.e_score_correction_bias is not None:
self.e_score_correction_bias = None

View file

@@ -0,0 +1,78 @@
'''
Date: 2024-11-13 15:05:52
LastEditors: Xie Weiyu ervinxie@qq.com
LastEditTime: 2024-11-25 08:59:19
'''
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
"""Fused operators for normalization layers."""
import logging
from typing import Optional, Tuple, Union
from transformers import PretrainedConfig
import torch
import torch.nn as nn
from ktransformers.models.modeling_deepseek_v3 import DeepseekV3RMSNorm
from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.util.custom_gguf import GGUFLoader
from flashinfer.norm import (
fused_add_rmsnorm,
rmsnorm,
)
logger = logging.getLogger(__name__)
class RMSNorm(DeepseekV3RMSNorm, BaseInjectedModule):
def __init__(self,
key: str,
gguf_loader : GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module,
prefill_device: str = "cuda",
generate_device: str = "cuda",
**kwargs):
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
self.orig_module.__init__(orig_module.hidden_size,
orig_module.variance_epsilon)
def forward(
self,
x: torch.Tensor,
batch_size_tensor: torch.Tensor = None,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
#return self.forward_native(x, residual)
if batch_size_tensor is None:
return self.forward_native(x)
if residual is not None:
fused_add_rmsnorm(x, residual, self.weight.data, batch_size_tensor, self.variance_epsilon)
#residual = x + residual
#out = rmsnorm(residual, self.weight.data, batch_size_tensor, self.variance_epsilon)
return x, residual
# print(x.shape, self.weight.data.shape, self.variance_epsilon, x.dtype, self.weight.data.dtype, x.device, self.weight.device, x.is_contiguous(), self.weight.data.is_contiguous())
out = rmsnorm(x, self.weight.data, batch_size_tensor, self.variance_epsilon)
return out
def forward_native(
self, hidden_states
):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
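
For readers without flashinfer at hand, the two kernels used above compute (ignoring the batch_size_tensor masking) the equivalent of this pure-PyTorch reference sketch; it mirrors the in-place semantics behind the (x, residual) return, not flashinfer's implementation:

import torch

def rmsnorm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    variance = x.float().pow(2).mean(-1, keepdim=True)
    return (weight * (x.float() * torch.rsqrt(variance + eps))).to(x.dtype)

def fused_add_rmsnorm_ref(x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float) -> None:
    residual.add_(x)                             # residual <- x + residual, in place
    x.copy_(rmsnorm_ref(residual, weight, eps))  # x <- rmsnorm(residual), in place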

View file

@@ -15,14 +15,16 @@ import ctypes
import torch
from torch import Tensor, nn
import KTransformersOps
import vLLMMarlin
from ktransformers.util.custom_gguf import GGUFLoader
from ktransformers.util.utils import InferenceState
from ktransformers.ktransformers_ext.operators.custom_marlin.quantize.utils.marlin_utils import (
MarlinWorkspace,
marlin_quantize,
GPTQ_MARLIN_MIN_THREAD_N,
GPTQ_MARLIN_MIN_THREAD_K,
GPTQ_MARLIN_MAX_PARALLEL,
vllm_marlin_quantize
)
from ktransformers.operators.base_operator import BaseInjectedModule
from transformers.configuration_utils import PretrainedConfig
@@ -84,8 +86,10 @@ class KLinearBase(ABC):
if self.gguf_loader.safetensor_loader is not None:
# using safetensor_loader
tensor = self.gguf_loader.safetensor_loader.load_tensor(key+'.weight')
weight_scale_inv = self.gguf_loader.safetensor_loader.load_tensor(key+'.weight_scale_inv')
return nn.Parameter(tensor), nn.Parameter(weight_scale_inv)
if key+'.weight_scale_inv' in self.gguf_loader.safetensor_loader.tensor_file_map:
weight_scale_inv = self.gguf_loader.safetensor_loader.load_tensor(key+'.weight_scale_inv')
return nn.Parameter(tensor), nn.Parameter(weight_scale_inv)
return nn.Parameter(tensor)
elif key + ".weight" in self.gguf_loader.tensor_file_map:
if key + ".bias" in self.gguf_loader.tensor_file_map:
@@ -134,7 +138,7 @@ class KLinearTorch(KLinearBase):
self.weight = None
self.has_bias = False
def forward(self, x: torch.Tensor) -> torch.Tensor:
def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
dtype = x.dtype
out_device = x.device
# TODO: support CUDA Graph when using cpu, but CPUInfer is recommended.
@@ -178,7 +182,6 @@ class KLinearTorch(KLinearBase):
if self.has_bias:
self.bias = None
class KLinearQ8(KLinearBase):
def __init__(
self,
@@ -370,7 +373,7 @@ class KLinearFP8(KLinearBase):
self.dtype = torch.get_default_dtype()
self.block_size = block_size
def forward(self, x: torch.Tensor) -> torch.Tensor:
def forward(self, x: torch.Tensor, bsz_tensor: torch.Tensor) -> torch.Tensor:
x = x.to(self.device)
orig_dtype = x.dtype
x_quantized, scale_x = act_quant(x, self.block_size)
@@ -397,8 +400,152 @@ class KLinearFP8(KLinearBase):
self.weight = None
if self.has_bias:
self.bias = None
# TODO: merge the two Marlin classes
class VLinearMarlin(KLinearBase):
marlin_q_w: torch.Tensor
marlin_s: torch.Tensor
g_idx: torch.Tensor
sort_indices: torch.Tensor
has_bias: bool
def __init__(
self,
key: str,
gguf_loader: GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module = None,
device: str = "cuda",
num_bits: int = 4, # 4-bit/8-bit is supported
group_size: int = 64, # -1, 32, 64, 128
act_order: bool = False,
is_k_full=True,
**kwargs,
):
assert device.lower() != "cpu", "Marlin quantized linear only supports GPU device"
super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
self.num_bits = num_bits
self.group_size = group_size
self.act_order = act_order
self.is_k_full = is_k_full
self.padding = False
self.orin_in_features = self.in_features
self.orin_out_features = self.out_features
if self.in_features%GPTQ_MARLIN_MIN_THREAD_K!=0 or self.out_features%GPTQ_MARLIN_MIN_THREAD_K!=0:
#print(f"warning!, in_features={in_features} or out_features={out_features} is undivisible by GPTQ_MARLIN_MIN_THREAD_K={GPTQ_MARLIN_MIN_THREAD_K} and GPTQ_MARLIN_MIN_THREAD_N={GPTQ_MARLIN_MIN_THREAD_N}, padding")
self.padding = True
self.in_features = (self.in_features+GPTQ_MARLIN_MIN_THREAD_K-1)//GPTQ_MARLIN_MIN_THREAD_K*GPTQ_MARLIN_MIN_THREAD_K
self.out_features = (self.out_features+GPTQ_MARLIN_MIN_THREAD_N-1)//GPTQ_MARLIN_MIN_THREAD_N*GPTQ_MARLIN_MIN_THREAD_N
#print(f"After padding: in_features={in_features}, out_features={out_features}")
self.k = self.in_features
self.n = self.out_features
def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None):
if self.loaded: return
if device is None: device = self.device
assert device.lower() != "cpu", "Marlin quantized linear only supports GPU device"
#if self.in_features * self.out_features:
if w is None:
w = self.load_weight(device=device)
if isinstance(w, nn.Parameter):
# pad weight
weight = w.view(self.orin_out_features, self.orin_in_features).T
self.has_bias = False
elif isinstance(w, tuple):
w = list(w)
weight = w[0].view(self.orin_out_features, self.orin_in_features).T
self.bias = w[1].view(self.orin_out_features)
self.has_bias = True
else:
raise ValueError("Invalid weight type")
weight = weight.to(device)
if self.has_bias:
self.bias = self.bias.to(device)
if self.padding:
padded_weight = torch.zeros(self.in_features, self.out_features, device=self.device)
padded_weight[:self.orin_in_features, :self.orin_out_features] = weight
weight = padded_weight
# Pack Marlin linear
marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
weight, self.num_bits, self.group_size, self.act_order
)
self.workspace = MarlinWorkspace(
self.out_features, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL,self.device
)
self.weight = marlin_q_w
self.marlin_q_w = marlin_q_w
self.marlin_s = marlin_s
self.g_idx = g_idx
self.sort_indices = sort_indices
self.k = weight.shape[0]
self.n = weight.shape[1]
# self.shape_buffer = torch.tensor([60], dtype=torch.int32, device=self.device)
self.loaded = True
def forward(self, x: torch.Tensor, bsz_tensor: torch.Tensor = None) -> torch.Tensor:
if bsz_tensor is None:
bsz_tensor = torch.tensor([x.shape[0]], dtype=torch.int32, device=self.device)
# Only support input x as BF16 and FP16
x = x.to(self.device)
orig_shape = list(x.shape)
orig_dtype = x.dtype
x = x.reshape(-1, orig_shape[-1])
marlin_s = self.marlin_s.to(x.dtype)
sms = -1
x = vLLMMarlin.gptq_marlin_gemm(
x,
self.marlin_q_w,
marlin_s,
self.g_idx,
self.sort_indices,
self.workspace.scratch,
self.num_bits,
bsz_tensor,
# torch.tensor([x.shape[0]], dtype=torch.int32, device=self.device),
x.shape[0],
self.n,
x.shape[-1],
sms,
self.is_k_full,
)
# x = KTransformersOps.gptq_marlin_gemm(
# x,
# self.marlin_q_w,
# marlin_s,
# self.g_idx,
# self.sort_indices,
# self.workspace.scratch,
# self.num_bits,
# x.shape[0],
# self.n,
# x.shape[-1],
# self.is_k_full,
# )
if self.has_bias:
x = x + self.bias
orig_shape[-1] = self.n
return x.reshape(orig_shape).to(orig_dtype)
def unload(self):
if self.has_bias:
self.bias = None
self.marlin_q_w = None
self.marlin_s = None
self.g_idx = None
self.sort_indices = None
self.workspace = None
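
The practical difference from KLinearMarlin is that the effective batch size reaches the kernel as a device-side tensor, so a CUDA graph captured at a padded batch can be replayed while only the first bsz_tensor rows are live. A hedged call sketch (layer construction and loading elided):

# linear: a loaded VLinearMarlin on "cuda"
x = torch.randn(8, linear.orin_in_features, dtype=torch.float16, device="cuda")
bsz = torch.tensor([3], dtype=torch.int32, device="cuda")  # only 3 of the 8 rows are valid
y = linear.forward(x, bsz)                                 # rows past bsz are padding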
class KLinearMarlin(KLinearBase):
marlin_q_w: torch.Tensor
marlin_s: torch.Tensor
@@ -483,7 +630,7 @@ class KLinearMarlin(KLinearBase):
self.n = weight.shape[1]
self.loaded = True
def forward(self, x: torch.Tensor) -> torch.Tensor:
def forward(self, x: torch.Tensor, bsz_tensor: torch.Tensor=None, **kwargs) -> torch.Tensor:
# Only support input x as BF16 and FP16
x = x.to(self.device)
orig_shape = list(x.shape)
@@ -629,12 +776,13 @@ class KLinearCPUInfer(KLinearBase):
if self.w is not None:
self.w = None
if self.has_bias:
self.bias = None
LINEAR_MAP = {
"KLinearMarlin": KLinearMarlin,
"KLinearTorch": KLinearTorch,
"KLinearCPUInfer": KLinearCPUInfer,
"VLinearMarlin": VLinearMarlin,
"KLinearFP8": KLinearFP8,
"KLinearQ8": KLinearQ8,
}
@@ -668,13 +816,13 @@ class KTransformersLinear(BaseInjectedModule, KLinearBase):
self.generate_linear = None
self.mode = InferenceState.UNLOAD
def forward(self, x):
def forward(self, x, bsz_tensor=None):
if self.mode == InferenceState.PREFILL:
assert self.prefill_linear is not None, "cpu linear is not initialized"
y = self.prefill_linear.forward(x)
y = self.prefill_linear.forward(x, bsz_tensor)
else:
assert self.generate_linear is not None, "gpu linear is not initialized"
y = self.generate_linear.forward(x)
y = self.generate_linear.forward(x, bsz_tensor)
return y
def load(self, w: dict | nn.Parameter | tuple | None = None, mode: InferenceState = InferenceState.GENERATE):
@ -717,3 +865,5 @@ class KTransformersLinear(BaseInjectedModule, KLinearBase):
self.unload()
else:
raise ValueError("mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD")

View file

@@ -0,0 +1,23 @@
from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.util.custom_gguf import GGUFLoader
from transformers import PretrainedConfig
import torch.nn as nn
from ktransformers.models.modeling_deepseek_v3 import DeepseekV3MLP
class kDeepseekV3MLP(DeepseekV3MLP, BaseInjectedModule):
def __init__(self,
key: str,
gguf_loader : GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module,
prefill_device: str = "cuda",
generate_device: str = "cuda",
**kwargs):
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
self.orig_module.__init__(orig_module.config,
orig_module.hidden_size, orig_module.intermediate_size)
def forward(self, x, bsz_tensor):
down_proj = self.down_proj(self.act_fn(self.gate_proj(x, bsz_tensor)) * self.up_proj(x, bsz_tensor), bsz_tensor)
return down_proj
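
A hedged call sketch for the injected MLP (module construction elided; shapes illustrative). Each projection is a KTransformersLinear, so bsz_tensor is threaded straight through to the underlying kernels:

# mlp: a kDeepseekV3MLP produced by the injection framework
x = torch.randn(4, config.hidden_size, dtype=torch.bfloat16, device="cuda")
bsz_tensor = torch.tensor([4], dtype=torch.int32, device="cuda")
out = mlp(x, bsz_tensor)   # same shape as x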