Mirror of https://github.com/kvcache-ai/ktransformers.git (synced 2025-09-09 22:05:30 +00:00)
add balance-serve, support concurrency
This commit is contained in:
  parent 8d0292aa44
  commit 25cee5810e

196 changed files with 22077 additions and 565 deletions
@@ -359,3 +359,56 @@ class DynamicNTKScalingRotaryEmbedding(
            self.orig_module.rope_type,
            self.orig_module.config,
        )


class RotaryEmbeddingV4(BaseInjectedModule):
    def __init__(
        self,
        key: str,
        gguf_loader: GGUFLoader,
        config: PretrainedConfig,
        orig_module: nn.Module,
        # device: str = "cuda",
        generate_device: str = "cuda",
        prefill_device: str = "cuda",
        **kwargs,
    ):
        BaseInjectedModule.__init__(
            self, key, gguf_loader, config, orig_module, generate_device, **kwargs
        )
        self.generate_device = generate_device
        self.prefill_device = prefill_device

    @torch.no_grad()
    def forward(self, x, position_ids):
        # x: [bs, num_attention_heads, seq_len, head_size]
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()
        # Force float32 since bfloat16 loses precision on long contexts
        # See https://github.com/huggingface/transformers/pull/29285
        device_type = x.device.type
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)

    def load(self):
        self._init(
            dim=self.config.qk_rope_head_dim,
            max_position_embeddings=self.config.max_position_embeddings,
            base=self.config.rope_theta,
            device=self.device,
        )

    def _init(self, dim, max_position_embeddings, base, device, scaling_factor=1.0):
        self.scaling_factor = scaling_factor
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
        # self.register_buffer("inv_freq", inv_freq, persistent=False)
        # For BC we register cos and sin cached
        self.max_seq_len_cached = max_position_embeddings
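Note: the cos/sin pair returned by RotaryEmbeddingV4.forward is consumed the same way as in the HuggingFace Llama code; a minimal sketch of that application, mirroring the transformers rotate_half / apply_rotary_pos_emb helpers the attention code pairs it with:

import torch

def rotate_half(x):
    # split the last dimension in half and rotate the two halves
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    # broadcast cos/sin over the head dimension, then rotate q and k
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed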
@@ -32,7 +32,8 @@ import os
from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
if flashinfer_enabled:
    from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton

from flashinfer.mla import BatchMLAPagedAttentionWrapper
from ktransformers.models.custom_cache import KDeepSeekV3Cache
logger = logging.getLogger("attention")

# Copied from transformers.models.llama.modeling_llama.rotate_half
@@ -759,3 +760,92 @@ class KLlamaAttention(BaseInjectedModule):
            attn_weights = None

        return attn_output, attn_weights, past_key_value


class flashinfer_attn(BaseInjectedModule, DeepseekV2Attention):
    def __init__(self,
                 key: str,
                 gguf_loader : GGUFLoader,
                 config: PretrainedConfig,
                 orig_module: nn.Module,
                 prefill_device: str = "cuda",
                 generate_device: str = "cuda",
                 chunck_size: int = 1000,
                 **kwargs):
        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
        self.orig_module.__init__(orig_module.config,
                                  orig_module.layer_idx)
        self.chunck_size = chunck_size # TODO, generate chunck_size automatically.

    def get_absorbed(self) -> Tuple[torch.Tensor, torch.Tensor]:
        if not (hasattr(self, 'q_absorb') and hasattr(self, 'out_absorb')):
            kv_b_proj = self.kv_b_proj.weight.view(self.num_heads, -1, self.kv_lora_rank)
            q_absorb = kv_b_proj[:, :self.qk_nope_head_dim, :].reshape(-1, self.kv_lora_rank)
            out_absorb = kv_b_proj[:, self.qk_nope_head_dim:, :].reshape(-1, self.kv_lora_rank)
            self.q_absorb = nn.Linear(self.kv_lora_rank, self.num_heads * self.qk_nope_head_dim,
                                      bias=False, dtype=q_absorb.dtype, device=q_absorb.device)
            self.q_absorb.weight.data = q_absorb
            self.out_absorb = nn.Linear(self.kv_lora_rank, self.num_heads * self.v_head_dim,
                                        bias=False, dtype=out_absorb.dtype, device=out_absorb.device)
            self.out_absorb.weight.data = out_absorb
            #del self.orig_module.kv_b_proj
        q_absorb = self.q_absorb.weight.view(self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank)
        out_absorb = self.out_absorb.weight.view(self.num_heads, self.v_head_dim, self.kv_lora_rank)
        return q_absorb, out_absorb

    def forward(self,
                hidden_states: torch.Tensor,
                kv_cache: KDeepSeekV3Cache,
                position_ids: torch.Tensor,
                wrapper: BatchMLAPagedAttentionWrapper,
                num_tokens_tensors: torch.Tensor,
                page_idx: torch.Tensor,
                page_offset: torch.Tensor,
                ):
        q_len, _ = hidden_states.size()

        if self.q_lora_rank is None:
            q = self.q_proj(hidden_states, num_tokens_tensors)
        else:
            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states, num_tokens_tensors), num_tokens_tensors), num_tokens_tensors)
        q = q.view(q_len, self.num_heads, self.q_head_dim)
        q_nope, q_pe = torch.split(
            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
        )

        compressed_kv = self.kv_a_proj_with_mqa(hidden_states, num_tokens_tensors)
        compressed_kv, k_pe = torch.split(
            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
        )
        compressed_kv = compressed_kv.contiguous()
        compressed_kv = self.kv_a_layernorm(compressed_kv, num_tokens_tensors)
        k_pe = k_pe.view(q_len, 1, self.qk_rope_head_dim)
        compressed_kv = compressed_kv.view(q_len, 1, self.kv_lora_rank)

        cos, sin = self.rotary_emb(q_pe, position_ids.unsqueeze(0))
        q_pe, k_pe = apply_rotary_pos_emb(q_pe.unsqueeze(0), k_pe.unsqueeze(0), cos, sin, unsqueeze_dim=2)
        q_pe = q_pe.squeeze(0)
        if kv_cache is not None:
            # page_idx, page_offset = kv_cache.get_page_table(position_ids, q_indptr, kv_indptr, kv_indices)
            cache_kwargs = {"sin": sin, "cos": cos, "page_idx": page_idx, "page_offset": page_offset} # Specific to RoPE models
            compressed_kv_with_k_pe = kv_cache.update(compressed_kv.unsqueeze(0), k_pe, self.layer_idx, page_idx, page_offset, cache_kwargs)
            compressed_kv = compressed_kv_with_k_pe[:, :, :, :self.kv_lora_rank].view(-1, kv_cache.page_size, self.kv_lora_rank)
            k_pe = compressed_kv_with_k_pe[:, :, :, self.kv_lora_rank:].view(-1, kv_cache.page_size, self.qk_rope_head_dim)

        q_absorb, out_absorb = self.get_absorbed()
        q_nope = q_nope.transpose(0, 1) # q_len is 1, no GPU overhead, same below
        q_nope = torch.matmul(q_nope, q_absorb) # batched MM
        q_nope = q_nope.transpose(0, 1)
        # q_nope.squeeze_(1)
        # q_pe.squeeze_(1)

        attn_output = wrapper.run(q_nope, q_pe, compressed_kv, k_pe).view(q_len, self.num_heads, self.kv_lora_rank)
        attn_output = attn_output.transpose(0, 1)
        attn_output = torch.matmul(attn_output, out_absorb.mT) # [self.num_heads, q_len, self.v_head_dim]
        attn_output = attn_output.transpose(0, 1)
        attn_output = attn_output.reshape(q_len, self.num_heads * self.v_head_dim)
        attn_output = self.o_proj(attn_output, num_tokens_tensors)
        return attn_output
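Note: the transpose/matmul/transpose sequence above is the MLA "matrix absorption" trick: q_nope is projected into the compressed-KV (latent) space so attention can run directly against compressed_kv. A tiny shape-only sketch with made-up dimensions:

import torch

num_heads, q_len, d_nope, kv_lora = 4, 3, 16, 32
q_nope = torch.randn(q_len, num_heads, d_nope)
q_absorb = torch.randn(num_heads, d_nope, kv_lora)  # per-head slice of kv_b_proj

# same transpose -> batched matmul -> transpose dance as in forward()
q_latent = torch.matmul(q_nope.transpose(0, 1), q_absorb).transpose(0, 1)
print(q_latent.shape)  # torch.Size([3, 4, 32])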
@@ -37,6 +37,10 @@ import time
from ktransformers.operators.cpuinfer import CPUInfer


def deduplicate_and_sort(lst):
    return sorted(set(lst))
#cuda_graphs = [Config().chunk_size]
cuda_graphs = deduplicate_and_sort([1, 2, 3, Config().max_batch_size, 64, Config().chunk_size])
# class Base(BaseInjectedModule, ABC):
class KExpertsBase(ABC):
    def __init__(self, key: str, gguf_loader: GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, device: str = "cuda", **kwargs):
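Note: cuda_graphs is a sorted list of batch-size buckets for which pinned buffers (and CUDA graphs) are pre-built. A hedged sketch of how a request's batch size could be mapped to a bucket index; the helper name and the fallback convention (cuda_graph_idx == -1 meaning "no pre-built bucket") are illustrative, not part of the diff:

import bisect

def pick_cuda_graph_idx(batch_size, cuda_graphs):
    # smallest pre-built bucket that can hold this batch; -1 -> fall back to the non-graph path
    i = bisect.bisect_left(cuda_graphs, batch_size)
    return i if i < len(cuda_graphs) else -1

# e.g. with an illustrative cuda_graphs = [1, 2, 3, 64, 256], a batch of 5 maps to the 64-slot bucket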
@@ -112,6 +116,7 @@ class KExpertsBase(ABC):
            tensors[k] = self.gguf_loader.load_gguf_tensor(key + k, device=device)
        return tensors


class KExpertsCPU(KExpertsBase):
    input_tensor_cpu:Tensor = None
    expert_ids_cpu:Tensor = None
@@ -119,8 +124,8 @@ class KExpertsCPU(KExpertsBase):
    output_cpu:Tensor = None
    output_gpu_map:dict = {} # Manage output tensor buffer on different gpu
    #stream_map:dict = {} # Manage cuda stream on different gpu
    #gguf_loader:GGUFLoader = None
    CPU_INFER = None
    # @TODO add yaml
    CPU_INFER = CPUInfer(Config().cpu_infer)
    def __init__(
        self,
        key: str,
@@ -133,11 +138,6 @@ class KExpertsCPU(KExpertsBase):
        **kwargs
    ):
        super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
        if KExpertsCPU.CPU_INFER is None:
            KExpertsCPU.CPU_INFER = CPUInfer(Config().cpu_infer)
        #if KExpertsCPU.gguf_loader is None:
        #    KExpertsCPU.gguf_loader = GGUFLoader("/mnt/data/model/DeepseekV3-q4km-gguf")
        self.gguf_loader = gguf_loader
        assert device.lower() == "cpu", "KExpertsCPU can only be loaded on CPU"
        self.n_routed_experts = n_routed_experts
        self.out_device = out_device
@@ -161,7 +161,7 @@ class KExpertsCPU(KExpertsBase):
        down_ptr = ctypes.addressof(
            ctypes.cast(self.down.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents
        )
        #print(self.gate_type, self.up_type, self.down_type)
        # print(self.gate_qtype, self.up_qtype, self.down_qtype)
        n_routed_experts = self.n_routed_experts
        # n_routed_experts = len(self.orig_module)
        moe_config = MOEConfig(
@@ -188,43 +188,83 @@ class KExpertsCPU(KExpertsBase):
            self.cpu_infer.submit(self.moe.warm_up())
            self.cpu_infer.sync()
        if self.out_device not in KExpertsCPU.output_gpu_map:
            KExpertsCPU.output_gpu_map[self.out_device] = torch.zeros((self.config.hidden_size), device=self.out_device)
            if isinstance(cuda_graphs, list):
                KExpertsCPU.output_gpu_map[self.out_device] = [torch.zeros((cuda_graphs[i], self.config.hidden_size), device=self.out_device) for i in range(len(cuda_graphs))]
            else:
                KExpertsCPU.output_gpu_map[self.out_device] = torch.zeros((cuda_graphs, self.config.hidden_size), device=self.out_device)
        if KExpertsCPU.input_tensor_cpu == None:
            KExpertsCPU.input_tensor_cpu = torch.zeros((self.config.hidden_size), device="cpu", pin_memory=True)
            KExpertsCPU.expert_ids_cpu = torch.zeros((num_experts_per_tok), device="cpu", dtype=torch.long, pin_memory=True)
            KExpertsCPU.weights_cpu = torch.zeros((num_experts_per_tok), device="cpu", dtype=torch.float32, pin_memory=True)
            KExpertsCPU.output_cpu = torch.zeros((self.config.hidden_size), device="cpu", pin_memory=True, dtype=torch.bfloat16)
            if isinstance(cuda_graphs, list):
                KExpertsCPU.input_tensor_cpu = [torch.zeros((cuda_graphs[i], self.config.hidden_size), device="cpu", pin_memory=True) for i in range(len(cuda_graphs))]
                KExpertsCPU.expert_ids_cpu = [torch.zeros((cuda_graphs[i], num_experts_per_tok), device="cpu", dtype=torch.long, pin_memory=True) for i in range(len(cuda_graphs))]
                KExpertsCPU.weights_cpu = [torch.zeros((cuda_graphs[i], num_experts_per_tok), device="cpu", dtype=torch.float32, pin_memory=True) for i in range(len(cuda_graphs))]
                KExpertsCPU.output_cpu = [torch.zeros((cuda_graphs[i], self.config.hidden_size), device="cpu", pin_memory=True, dtype=torch.bfloat16) for i in range(len(cuda_graphs))]
                KExpertsCPU.bsz_tensor_cpu = [torch.zeros((1), device="cpu", dtype=torch.int32, pin_memory=True) for i in range(len(cuda_graphs))]
            else:
                KExpertsCPU.input_tensor_cpu = torch.zeros((cuda_graphs, self.config.hidden_size), device="cpu", pin_memory=True)
                KExpertsCPU.expert_ids_cpu = torch.zeros((cuda_graphs, num_experts_per_tok), device="cpu", dtype=torch.long, pin_memory=True)
                KExpertsCPU.weights_cpu = torch.zeros((cuda_graphs, num_experts_per_tok), device="cpu", dtype=torch.float32, pin_memory=True)
                KExpertsCPU.output_cpu = torch.zeros((cuda_graphs, self.config.hidden_size), device="cpu", pin_memory=True, dtype=torch.bfloat16)
                KExpertsCPU.bsz_tensor_cpu = torch.zeros((1), device="cpu", dtype=torch.int32, pin_memory=True)

    def submit_for_one_decode(self, input_tensor, expert_ids, weights):
        KExpertsCPU.input_tensor_cpu.copy_(input_tensor, non_blocking=True)
        KExpertsCPU.expert_ids_cpu.copy_(expert_ids, non_blocking=True)
        KExpertsCPU.weights_cpu.copy_(weights, non_blocking=True)
        self.cpu_infer.submit_with_cuda_stream(torch.cuda.current_stream(self.out_device).cuda_stream, self.moe.forward(1, expert_ids.size(0), KExpertsCPU.expert_ids_cpu.data_ptr(), KExpertsCPU.weights_cpu.data_ptr(), KExpertsCPU.input_tensor_cpu.data_ptr(), KExpertsCPU.output_cpu.data_ptr()))

    def sync_for_one_decode(self):
        self.cpu_infer.sync_with_cuda_stream(torch.cuda.current_stream(self.out_device).cuda_stream)
        KExpertsCPU.output_gpu_map[self.out_device].copy_(KExpertsCPU.output_cpu, non_blocking=True)
        return KExpertsCPU.output_gpu_map[self.out_device]

    def forward(self, input_tensor, expert_ids, weights):
        # generate, capture and run cuda graph
        # print(expert_ids)
        if input_tensor.size(0)==1 and torch.cuda.is_current_stream_capturing():
            # TODO: this branch is unreachable, but the shape of input_tensor([1,hidden_size]) and input_tensor_cpu([hidden_size]) is not compatible
            #print("capturing experts")
    def submit_for_one_decode(self, input_tensor, expert_ids, weights, bsz_tensor=None, cuda_graph_idx=0):
        if bsz_tensor is None:
            bsz_tensor = torch.ones(1, device=input_tensor.device, dtype=torch.int32)
        if cuda_graph_idx != -1:
            KExpertsCPU.input_tensor_cpu[cuda_graph_idx].copy_(input_tensor, non_blocking=True)
            KExpertsCPU.expert_ids_cpu[cuda_graph_idx].copy_(expert_ids, non_blocking=True)
            KExpertsCPU.weights_cpu[cuda_graph_idx].copy_(weights, non_blocking=True)
            KExpertsCPU.bsz_tensor_cpu[cuda_graph_idx].copy_(bsz_tensor, non_blocking=True)
            self.cpu_infer.submit_with_cuda_stream(torch.cuda.current_stream(self.out_device).cuda_stream, self.moe.forward(1, expert_ids.size(-1), KExpertsCPU.expert_ids_cpu[cuda_graph_idx].data_ptr(), KExpertsCPU.weights_cpu[cuda_graph_idx].data_ptr(), KExpertsCPU.input_tensor_cpu[cuda_graph_idx].data_ptr(), KExpertsCPU.output_cpu[cuda_graph_idx].data_ptr(), KExpertsCPU.bsz_tensor_cpu[cuda_graph_idx].data_ptr()))
        else:
            KExpertsCPU.input_tensor_cpu.copy_(input_tensor, non_blocking=True)
            KExpertsCPU.expert_ids_cpu.copy_(expert_ids, non_blocking=True)
            KExpertsCPU.weights_cpu.copy_(weights, non_blocking=True)
            self.cpu_infer.submit_with_cuda_stream(torch.cuda.current_stream().cuda_stream, self.moe.forward(1, expert_ids.size(1), KExpertsCPU.expert_ids_cpu.data_ptr(), KExpertsCPU.weights_cpu.data_ptr(), KExpertsCPU.input_tensor_cpu.data_ptr(), KExpertsCPU.output_cpu.data_ptr()))
            self.cpu_infer.sync_with_cuda_stream(torch.cuda.current_stream().cuda_stream)
            KExpertsCPU.bsz_tensor_cpu.copy_(bsz_tensor, non_blocking=True)
            self.cpu_infer.submit_with_cuda_stream(torch.cuda.current_stream(self.out_device).cuda_stream, self.moe.forward(1, expert_ids.size(-1), KExpertsCPU.expert_ids_cpu.data_ptr(), KExpertsCPU.weights_cpu.data_ptr(), KExpertsCPU.input_tensor_cpu.data_ptr(), KExpertsCPU.output_cpu.data_ptr(), KExpertsCPU.bsz_tensor_cpu.data_ptr()))

    def sync_for_one_decode(self, cuda_graph_idx=0):
        if cuda_graph_idx != -1:
            self.cpu_infer.sync_with_cuda_stream(torch.cuda.current_stream(self.out_device).cuda_stream)
            KExpertsCPU.output_gpu_map[self.out_device][cuda_graph_idx].copy_(KExpertsCPU.output_cpu[cuda_graph_idx], non_blocking=True)
            return KExpertsCPU.output_gpu_map[self.out_device][cuda_graph_idx]
        else:
            self.cpu_infer.sync_with_cuda_stream(torch.cuda.current_stream(self.out_device).cuda_stream)
            KExpertsCPU.output_gpu_map[self.out_device].copy_(KExpertsCPU.output_cpu, non_blocking=True)
            return KExpertsCPU.output_gpu_map[self.out_device]

    def forward(self, input_tensor, expert_ids, weights, bsz_tensor=None, cuda_graph_idx=0):
        # generate, capture and run cuda graph
        # print(expert_ids)
        if bsz_tensor is None:
            bsz_tensor = torch.tensor([input_tensor.size(0)], device=input_tensor.device, dtype=torch.int32)
        if torch.cuda.is_current_stream_capturing():
            if cuda_graph_idx != -1:
                KExpertsCPU.input_tensor_cpu[cuda_graph_idx].copy_(input_tensor, non_blocking=True)
                KExpertsCPU.expert_ids_cpu[cuda_graph_idx].copy_(expert_ids, non_blocking=True)
                KExpertsCPU.weights_cpu[cuda_graph_idx].copy_(weights, non_blocking=True)
                KExpertsCPU.bsz_tensor_cpu[cuda_graph_idx].copy_(bsz_tensor, non_blocking=True)
                self.cpu_infer.submit_with_cuda_stream(torch.cuda.current_stream().cuda_stream, self.moe.forward(expert_ids.size(0), expert_ids.size(-1), KExpertsCPU.expert_ids_cpu[cuda_graph_idx].data_ptr(), KExpertsCPU.weights_cpu[cuda_graph_idx].data_ptr(), KExpertsCPU.input_tensor_cpu[cuda_graph_idx].data_ptr(), KExpertsCPU.output_cpu[cuda_graph_idx].data_ptr(), KExpertsCPU.bsz_tensor_cpu[cuda_graph_idx].data_ptr()))
                self.cpu_infer.sync_with_cuda_stream(torch.cuda.current_stream().cuda_stream)
                KExpertsCPU.output_gpu_map[self.out_device][cuda_graph_idx].copy_(KExpertsCPU.output_cpu[cuda_graph_idx], non_blocking=True)
                return KExpertsCPU.output_gpu_map[self.out_device][cuda_graph_idx]
            else:
                KExpertsCPU.input_tensor_cpu.copy_(input_tensor, non_blocking=True)
                KExpertsCPU.expert_ids_cpu.copy_(expert_ids, non_blocking=True)
                KExpertsCPU.weights_cpu.copy_(weights, non_blocking=True)
                KExpertsCPU.bsz_tensor_cpu.copy_(bsz_tensor, non_blocking=True)
                self.cpu_infer.submit_with_cuda_stream(torch.cuda.current_stream().cuda_stream, self.moe.forward(expert_ids.size(0), expert_ids.size(-1), KExpertsCPU.expert_ids_cpu.data_ptr(), KExpertsCPU.weights_cpu.data_ptr(), KExpertsCPU.input_tensor_cpu.data_ptr(), KExpertsCPU.output_cpu.data_ptr(), KExpertsCPU.bsz_tensor_cpu.data_ptr()))
                self.cpu_infer.sync_with_cuda_stream(torch.cuda.current_stream().cuda_stream)
                KExpertsCPU.output_gpu_map[self.out_device].copy_(KExpertsCPU.output_cpu, non_blocking=True)
                return KExpertsCPU.output_gpu_map[self.out_device]
        else:
            input_tensor = input_tensor.contiguous().cpu()
            expert_ids = expert_ids.contiguous().cpu()
            weights = weights.contiguous().to(torch.float32).cpu()
            bsz_tensor = bsz_tensor.contiguous().cpu()
            output = torch.empty_like(input_tensor).contiguous()
            self.cpu_infer.submit(self.moe.forward(expert_ids.size(0), expert_ids.size(1), expert_ids.data_ptr(), weights.data_ptr(), input_tensor.data_ptr(), output.data_ptr()))
            self.cpu_infer.submit(self.moe.forward(expert_ids.size(0), expert_ids.size(1), expert_ids.data_ptr(), weights.data_ptr(), input_tensor.data_ptr(), output.data_ptr(), bsz_tensor.data_ptr()))
            self.cpu_infer.sync()
            return output.to(device=object.__getattribute__(self, "out_device"))
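Note: the decode path above always stages activations through fixed, pinned CPU buffers so the host copies are truly asynchronous and the buffer addresses stay stable under CUDA graph capture. A minimal stand-alone sketch of that staging pattern (buffer names are illustrative, the expert computation is a placeholder, and a CUDA device is assumed):

import torch

hidden = 16
pinned_in = torch.zeros(1, hidden, device="cpu", pin_memory=True)
pinned_out = torch.zeros(1, hidden, device="cpu", pin_memory=True)
gpu_out = torch.zeros(1, hidden, device="cuda")

def cpu_expert_step(x_gpu):
    # 1) async copy GPU activations into a fixed pinned buffer
    pinned_in.copy_(x_gpu, non_blocking=True)
    # 2) the CPU side runs the experts here, writing into pinned_out
    pinned_out.copy_(pinned_in)  # placeholder for the CPUInfer MoE kernel
    # 3) async copy the result back into a fixed GPU buffer
    gpu_out.copy_(pinned_out, non_blocking=True)
    return gpu_out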
@@ -859,6 +899,8 @@ class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE):
            y += y_
        return y


    @torch.no_grad()
    def moe_kexperts(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
        outs = self.experts(x, topk_ids, topk_weight)
@@ -1013,4 +1055,178 @@ class KMistralSparseMoEBlock(BaseInjectedModule, MixtralSparseMoeBlock):
            # the `top_x` tensor here.
            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states_cpu.dtype))

        return final_hidden_states
        return final_hidden_states

class KDeepseekV3MoEV2(BaseInjectedModule, DeepseekV3MoE):
    def forward(self, hidden_states, bsz_tensor, cuda_graph_idx=0):
        identity = hidden_states
        orig_shape = hidden_states.shape
        sequence_length = orig_shape[1]
        topk_idx, topk_weight = self.gate(hidden_states)
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])

        # only for generate phase
        if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug
            self.experts.generate_experts.submit_for_one_decode(hidden_states, topk_idx, topk_weight, bsz_tensor, cuda_graph_idx)
            if self.config.n_shared_experts is not None:
                y_ = self.shared_experts(identity, bsz_tensor).squeeze(0)
            y = self.experts.generate_experts.sync_for_one_decode(cuda_graph_idx).unsqueeze(0)
            y += y_
            y.resize_(*orig_shape)
            return y

        if self.config.n_shared_experts is not None:
            y_ = self.shared_experts(identity, bsz_tensor).squeeze(0)

        if isinstance(self.experts, KExpertsBase):
            y = self.moe_on_cpuinfer(hidden_states, topk_idx, topk_weight, bsz_tensor, cuda_graph_idx).view(*orig_shape).to(device=hidden_states.device)
        elif hidden_states.size(0) > 10:
            # TODO may bugs here
            y = (
                self.moe_infer(hidden_states, topk_idx, topk_weight)
                .view(*orig_shape)
                .to(device=hidden_states.device)
            )
        else:
            # TODO may bugs here
            y = (
                self.moe_infer_simple(hidden_states, topk_idx, topk_weight)
                .view(*orig_shape)
                .to(device=hidden_states.device)
            )
        if self.config.n_shared_experts is not None:
            y += y_
        return y

    @torch.no_grad()
    def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor, bsz_tensor, cuda_graph_idx=0) -> torch.Tensor:
        outs = torch.empty_like(x)
        outs = self.experts(x, topk_ids, topk_weight, bsz_tensor, cuda_graph_idx)
        return outs

    @torch.no_grad()
    # TODO may bugs here
    def moe_infer_simple(
        self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor
    ) -> torch.Tensor:
        """
        x: [num_tokens, hidden_size]
        topk_ids, topk_weight: [num_tokens, num_selected_experts]
        """
        outs = torch.zeros_like(x)
        for token_idx in range(topk_ids.size(0)):
            for expert_idx in range(topk_ids.size(1)):
                expert = self.experts[topk_ids[token_idx, expert_idx]]
                outs[token_idx] += (
                    expert.forward(x[token_idx]) * topk_weight[token_idx, expert_idx]
                )
        return outs

    @torch.no_grad()
    # TODO may bugs here
    def moe_infer(self, x, topk_ids, topk_weight):
        cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
        cnts.scatter_(1, topk_ids, 1)
        tokens_per_expert = cnts.sum(dim=0)
        idxs = topk_ids.view(-1).argsort()
        sorted_tokens = x[idxs // topk_ids.shape[1]]
        tokens_per_expert = tokens_per_expert.cpu().numpy()

        outputs = []
        start_idx = 0
        for i, num_tokens in enumerate(tokens_per_expert):
            end_idx = start_idx + num_tokens
            if num_tokens == 0:
                continue
            expert = self.experts[i + self.ep_rank * self.experts_per_rank]
            tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
            expert_out = expert.forward(tokens_for_this_expert)
            outputs.append(expert_out)
            start_idx = end_idx

        outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)

        new_x = torch.empty_like(outs)
        new_x[idxs] = outs
        final_out = (
            new_x.view(*topk_ids.shape, -1)
            .type(topk_weight.dtype)
            .mul_(topk_weight.unsqueeze(dim=-1))
            .sum(dim=1)
            .type(new_x.dtype)
        )
        return final_out

class KTransformersExpertsV2(BaseInjectedModule, KExpertsBase):
    def __init__(self,
                 key: str,
                 gguf_loader: GGUFLoader,
                 config: PretrainedConfig,
                 orig_module: nn.Module,
                 # device: str = "cuda",
                 prefill_device:str = "cuda",
                 prefill_op: str | None = "KExpertsTorch",
                 generate_device: str = "cpu",
                 generate_op: str | None = "KExpertsCPU",
                 **kwargs):
        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
        KExpertsBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
        if generate_op is not None:
            self.generate_experts = EXPERTS_MAP[generate_op](key, gguf_loader, config, len(orig_module), device=generate_device, **kwargs)
        else:
            self.generate_experts = None
        if prefill_op is not None:
            self.prefill_experts = EXPERTS_MAP[prefill_op](key, gguf_loader, config, len(orig_module), device=prefill_device, **kwargs)
        else:
            self.prefill_experts = None
        self.gpu_mlp_type = prefill_op
        self.cpu_mlp_type = generate_op
        self.mode = InferenceState.UNLOAD

    def load(self, w: dict = None, mode: InferenceState = None, warmup: bool = True):
        # TODO support w as input
        if not mode: mode = InferenceState.GENERATE
        if mode == InferenceState.GENERATE:
            self.prefill_experts.unload()
            self.generate_experts.load(w, warmup=warmup)
            self.device = self.generate_experts.device
            self.mode = mode
        elif mode == InferenceState.PREFILL:
            self.generate_experts.unload()
            self.prefill_experts.load(w, warmup=warmup)
            self.device = self.prefill_experts.device
            self.mode = mode
        elif mode == InferenceState.UNLOAD:
            self.unload()
            self.mode = mode
            self.device = self.generate_experts.device
        else:
            raise ValueError("mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD")

    def unload(self):
        if self.generate_experts is not None:
            self.generate_experts.unload()
        if self.prefill_experts is not None:
            self.prefill_experts.unload()
        self.device = self.generate_experts.device

    def forward(self, input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx=0):
        if self.mode == InferenceState.GENERATE:
            assert self.generate_experts is not None, "generate_experts is None"
            return self.generate_experts.forward(input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx)
        elif self.mode == InferenceState.PREFILL:
            assert self.prefill_experts is not None, "prefill_experts is None"
            return self.prefill_experts.forward(input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx)
        else:
            raise ValueError("load or set_inference_mode before forward")

    def set_inference_mode(self, mode: InferenceState):
        if mode == InferenceState.GENERATE:
            self.load(mode=InferenceState.GENERATE, warmup=False)
        elif mode == InferenceState.PREFILL:
            self.load(mode=InferenceState.PREFILL, warmup=False)
        elif mode == InferenceState.UNLOAD:
            self.unload()
        else:
            raise ValueError("mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD")
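Note: moe_infer above groups tokens by expert before dispatch, then scatters the expert outputs back and applies the router weights. A small self-contained sketch of the same sort-and-scatter bookkeeping on dummy data (identity "experts" stand in for the real modules):

import torch

num_experts, num_tokens, top_k, hidden = 4, 6, 2, 8
x = torch.randn(num_tokens, hidden)
topk_ids = torch.randint(0, num_experts, (num_tokens, top_k))
topk_weight = torch.rand(num_tokens, top_k)

idxs = topk_ids.view(-1).argsort()        # (token, slot) pairs reordered so same-expert tokens are contiguous
sorted_tokens = x[idxs // top_k]          # gather each token once per selected expert
outs = sorted_tokens                      # identity experts for illustration
new_x = torch.empty_like(outs)
new_x[idxs] = outs                        # scatter results back to (token, slot) order
y = (new_x.view(num_tokens, top_k, hidden) * topk_weight.unsqueeze(-1)).sum(dim=1)
print(y.shape)  # torch.Size([6, 8])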
@@ -86,6 +86,7 @@ class MLAWrapper():
            self.qo_indptr_buf = torch.empty(max_batch_size+1, dtype=torch.int32, device=device)
            self.kv_indptr_buf = torch.empty(max_batch_size+1, dtype=torch.int32, device=device)
            self.kv_indices_buf = torch.empty(max_pages, dtype=torch.int32, device=device)
            self.batch_size_tensor_buf = torch.tensor([self.max_batch_size], dtype=torch.int32, device=device)
            self.kv_len_arr_buf = torch.empty(max_batch_size, dtype=torch.int32, device=device)
        else:
            self.qo_indptr_buf = None
@@ -94,19 +95,22 @@ class MLAWrapper():
            self.kv_len_arr_buf = None
        self.wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(
            self.float_workspace_buffer,
            use_cuda_graph=False,
            use_cuda_graph=use_cuda_graph,
            qo_indptr=self.qo_indptr_buf,
            kv_indptr=self.kv_indptr_buf,
            kv_indices=self.kv_indices_buf,
            kv_len_arr=self.kv_len_arr_buf,
            bsz_tensor=self.batch_size_tensor_buf
        )
        self.need_plan = True

    def plan(self,
             qo_indptr,
             kv_indptr,
             kv_indices,
             kv_len_arr,
             bsz_tensor,
             num_heads,
             head_dim_ckv,
             head_dim_kpe,
@@ -138,6 +142,7 @@ class MLAWrapper():
            sm_scale,
            q_data_type,
            kv_data_type,
            bsz_tensor
        )

    def run(self, q_nope, q_pe, ckv, k_pe, return_lse = False):
@@ -240,16 +245,17 @@ if __name__ == "__main__":
    #checksame()
    #exit(0)

    max_batch_size = 1
    max_pages = 64
    max_batch_size = 2
    max_batch_tokens = 256
    max_pages = 128
    page_size = 64
    num_heads = 128

    # warm-up
    kv_len = 4023
    q_len = 1
    q_nope_buf = torch.randn((q_len, num_heads, 512), dtype=torch.bfloat16, device="cuda")
    q_pe_buf = torch.randn((q_len, num_heads, 64), dtype=torch.bfloat16, device="cuda")
    q_nope_buf = torch.randn((max_batch_tokens, num_heads, 512), dtype=torch.bfloat16, device="cuda")
    q_pe_buf = torch.randn((max_batch_tokens, num_heads, 64), dtype=torch.bfloat16, device="cuda")
    kv_buf = torch.randn((max_pages, page_size, 576), dtype=torch.bfloat16, device="cuda")
    ckv, k_pe = torch.split(kv_buf, [512, 64], dim=-1)
@@ -260,13 +266,19 @@ if __name__ == "__main__":
        max_pages,
    )

    used_pages = (kv_len + page_size - 1)// page_size
    kv_len_arr = torch.tensor([kv_len], dtype=torch.int32, device="cuda")
    qo_indptr = torch.tensor([0, q_len], dtype=torch.int32, device="cuda")
    kv_indptr = torch.tensor([0, used_pages], dtype=torch.int32, device="cuda")
    kv_indices = torch.empty(max_pages, dtype=torch.int32, device="cuda")
    kv_indices[:used_pages] = torch.arange(0, used_pages, dtype=torch.int32, device="cuda")
    bsz_tensor = torch.tensor([1], dtype=torch.int32, device="cuda")
    wrapper.plan(
        qo_indptr,
        None,
        None,
        kv_indptr,
        kv_indices,
        kv_len_arr,
        bsz_tensor,
        128,
        512,
        64,
@@ -276,14 +288,98 @@ if __name__ == "__main__":
        torch.bfloat16,
    )

    attn_output = wrapper.run(q_nope_buf, q_pe_buf, ckv, k_pe)
    attn_output = wrapper.run(q_nope_buf[:q_len], q_pe_buf[:q_len], ckv, k_pe)
    print(attn_output.shape)

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        attn_output = wrapper.run(q_nope_buf, q_pe_buf, ckv, k_pe)
    graph.replay()

    q = torch.cat([q_nope_buf, q_pe_buf], dim=-1)
    k = (
        torch.cat([ckv, k_pe], dim=-1)
        .view(-1, 1, 512 + 64)
        .repeat_interleave(num_heads, dim=1)
    )
    v = ckv.view(-1, 1, 512).repeat_interleave(num_heads, dim=1)
    attn_ref, lse_ref = attention_ref_torch(
        1,
        q[:q_len],
        k[:kv_len],
        v[:kv_len],
        True,
        192 ** (-0.5)
    )
    torch.testing.assert_close(attn_output[:q_len], attn_ref, rtol=5e-3, atol=5e-3)
    # warm-up finished

    kv_len = 512
    q_len = 128
    pages = max_pages
    used_pages = (kv_len + page_size - 1)// page_size
    q_nope = torch.randn((q_len*2, num_heads, 512), dtype=torch.bfloat16, device="cuda")
    q_nope[q_len:] = q_nope[:q_len]
    q_pe = torch.randn((q_len*2, num_heads, 64), dtype=torch.bfloat16, device="cuda")
    q_pe[q_len:] = q_pe[:q_len]
    kv_cache = torch.randn((max_pages, page_size, 576), dtype=torch.bfloat16, device="cuda")
    kv_cache[used_pages:2*used_pages] = kv_cache[:used_pages]
    ckv, k_pe = torch.split(kv_cache, [512, 64], dim=-1)

    kv_len_arr = torch.tensor([kv_len, kv_len], dtype=torch.int32, device="cuda")
    qo_indptr = torch.tensor([0, q_len, q_len*2], dtype=torch.int32, device="cuda")
    kv_indptr = torch.tensor([0, used_pages, used_pages*2], dtype=torch.int32, device="cuda")
    kv_indices = torch.empty(max_pages, dtype=torch.int32, device="cuda")
    kv_indices[:2*used_pages] = torch.arange(0, 2*used_pages, dtype=torch.int32, device="cuda")
    bsz_tensor = torch.tensor([2], dtype=torch.int32, device="cuda")
    wrapper.plan(
        qo_indptr,
        kv_indptr,
        kv_indices,
        kv_len_arr,
        bsz_tensor,
        128,
        512,
        64,
        page_size,
        192 ** (-0.5),
        torch.bfloat16,
        torch.bfloat16,
    )

    q_nope_buf.copy_(q_nope)
    q_pe_buf.copy_(q_pe)
    kv_buf[:pages].copy_(kv_cache)

    torch.cuda.synchronize()
    graph.replay()
    torch.cuda.synchronize()

    # ref_torch
    q = torch.cat([q_nope, q_pe], dim=-1)
    k = (
        torch.cat([ckv, k_pe], dim=-1)
        .view(-1, 1, 512 + 64)
        .repeat_interleave(num_heads, dim=1)
    )
    v = ckv.view(-1, 1, 512).repeat_interleave(num_heads, dim=1)
    attn_ref, lse_ref = attention_ref_torch(
        max_batch_size,
        q,
        k[:2*kv_len],
        v[:2*kv_len],
        True,
        192 ** (-0.5)
    )

    torch.testing.assert_close(attn_ref[:q_len], attn_ref[q_len:q_len*2], rtol=1e-9, atol=1e-9)
    torch.testing.assert_close(attn_output[:q_len], attn_output[q_len:q_len*2], rtol=1e-9, atol=1e-9)
    torch.testing.assert_close(attn_output[:q_len], attn_ref[:q_len], rtol=5e-3, atol=5e-3)
    torch.testing.assert_close(attn_output[q_len:q_len*2], attn_ref[q_len:q_len*2], rtol=5e-3, atol=5e-3)
    #torch.testing.assert_close(attn_output[:q_len], attn_output[q_len:q_len*2], rtol=1e-9, atol=1e-9)
    #torch.testing.assert_close(attn_output, attn_ref, rtol=5e-3, atol=5e-3)

    exit(0)

    for forward_id in range(0, 1):
        print("forward_id", forward_id)
        for layer_id in range(1):
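Note: the test above captures wrapper.run once into a CUDA graph and afterwards only copies fresh inputs into the same buffers before replaying. The general capture/replay pattern, independent of flashinfer (assumes a CUDA device):

import torch

static_in = torch.zeros(8, device="cuda")
static_out = torch.zeros(8, device="cuda")

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_out.copy_(static_in * 2)   # captured work reads/writes fixed buffers only

static_in.copy_(torch.arange(8.0, device="cuda"))
g.replay()                            # re-runs the captured kernels on the new contents
torch.cuda.synchronize()
print(static_out)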
@@ -376,5 +472,4 @@ if __name__ == "__main__":
            #file_name = f"./flashinfer_output/layer_{layer_id}_forward_{forward_id}_attn_output.pt"
            #ktrans_output = torch.load(file_name)
            #torch.testing.assert_close(attn_output, ktrans_output.squeeze(1), rtol=1e-3, atol=1e-3)
            print("test past")

    print("test past")
@@ -249,4 +249,4 @@ class KMoEGateDeepSeekV3(BaseInjectedModule, KMoEGateBase):
        if self.weight is not None:
            self.weight = None
        if self.e_score_correction_bias is not None:
            self.e_score_correction_bias = None
        self.e_score_correction_bias = None
ktransformers/operators/layernorm.py (new file, 78 lines)
@@ -0,0 +1,78 @@
'''
Date: 2024-11-13 15:05:52
LastEditors: Xie Weiyu ervinxie@qq.com
LastEditTime: 2024-11-25 08:59:19
'''
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""Fused operators for normalization layers."""

import logging
from typing import Optional, Tuple, Union
from transformers import PretrainedConfig
import torch
import torch.nn as nn
from ktransformers.models.modeling_deepseek_v3 import DeepseekV3RMSNorm
from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.util.custom_gguf import GGUFLoader
from flashinfer.norm import (
    fused_add_rmsnorm,
    rmsnorm,
)


logger = logging.getLogger(__name__)


class RMSNorm(DeepseekV3RMSNorm, BaseInjectedModule):
    def __init__(self,
                 key: str,
                 gguf_loader : GGUFLoader,
                 config: PretrainedConfig,
                 orig_module: nn.Module,
                 prefill_device: str = "cuda",
                 generate_device: str = "cuda",
                 **kwargs):
        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
        self.orig_module.__init__(orig_module.hidden_size,
                                  orig_module.variance_epsilon)

    def forward(
        self,
        x: torch.Tensor,
        batch_size_tensor: torch.Tensor = None,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        #return self.forward_native(x, residual)
        if batch_size_tensor is None:
            return self.forward_native(x)
        if residual is not None:
            fused_add_rmsnorm(x, residual, self.weight.data, batch_size_tensor, self.variance_epsilon)
            #residual = x + residual
            #out = rmsnorm(residual, self.weight.data, batch_size_tensor, self.variance_epsilon)
            return x, residual
        # print(x.shape, self.weight.data.shape, self.variance_epsilon, x.dtype, self.weight.data.dtype, x.device, self.weight.device, x.is_contiguous(), self.weight.data.is_contiguous())
        out = rmsnorm(x, self.weight.data, batch_size_tensor, self.variance_epsilon)
        return out

    def forward_native(
        self, hidden_states
    ):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)
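Note: forward_native is the reference the fused flashinfer path must match: y = weight * x / sqrt(mean(x^2) + eps). A quick pure-torch restatement of the same math (no flashinfer required, shapes and eps value are illustrative):

import torch

def rmsnorm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # same math as forward_native: cast to fp32, divide by the RMS, scale, cast back
    variance = x.float().pow(2).mean(-1, keepdim=True)
    return weight * (x.float() * torch.rsqrt(variance + eps)).to(x.dtype)

x = torch.randn(4, 16, dtype=torch.bfloat16)
w = torch.ones(16, dtype=torch.bfloat16)
print(rmsnorm_ref(x, w).shape)  # torch.Size([4, 16])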
@@ -15,14 +15,16 @@ import ctypes
import torch
from torch import Tensor, nn
import KTransformersOps
import vLLMMarlin
from ktransformers.util.custom_gguf import GGUFLoader
from ktransformers.util.utils import InferenceState
from ktransformers.ktransformers_ext.operators.custom_marlin.quantize.utils.marlin_utils import (
    MarlinWorkspace,
    marlin_quantize,
    marlin_quantize,
    GPTQ_MARLIN_MIN_THREAD_N,
    GPTQ_MARLIN_MIN_THREAD_K,
    GPTQ_MARLIN_MAX_PARALLEL,
    vllm_marlin_quantize
)
from ktransformers.operators.base_operator import BaseInjectedModule
from transformers.configuration_utils import PretrainedConfig
@@ -84,8 +86,10 @@ class KLinearBase(ABC):
        if self.gguf_loader.safetensor_loader is not None:
            # using safetensor_loader
            tensor = self.gguf_loader.safetensor_loader.load_tensor(key+'.weight')
            weight_scale_inv = self.gguf_loader.safetensor_loader.load_tensor(key+'.weight_scale_inv')
            return nn.Parameter(tensor), nn.Parameter(weight_scale_inv)
            if key+'.weight_scale_inv' in self.gguf_loader.safetensor_loader.tensor_file_map:
                weight_scale_inv = self.gguf_loader.safetensor_loader.load_tensor(key+'.weight_scale_inv')
                return nn.Parameter(tensor), nn.Parameter(weight_scale_inv)
            return nn.Parameter(tensor)

        elif key + ".weight" in self.gguf_loader.tensor_file_map:
            if key + ".bias" in self.gguf_loader.tensor_file_map:
@@ -134,7 +138,7 @@ class KLinearTorch(KLinearBase):
        self.weight = None
        self.has_bias = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        dtype = x.dtype
        out_device = x.device
        # TODO: support CUDA Graph when using cpu, but CPUInfer is recommended.
@@ -178,7 +182,6 @@ class KLinearTorch(KLinearBase):
        if self.has_bias:
            self.bias = None


class KLinearQ8(KLinearBase):
    def __init__(
        self,
@@ -370,7 +373,7 @@ class KLinearFP8(KLinearBase):
        self.dtype = torch.get_default_dtype()
        self.block_size = block_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
    def forward(self, x: torch.Tensor, bsz_tensor: torch.Tensor) -> torch.Tensor:
        x = x.to(self.device)
        orig_dtype = x.dtype
        x_quantized, scale_x = act_quant(x, self.block_size)
@@ -397,8 +400,152 @@ class KLinearFP8(KLinearBase):
        self.weight = None
        if self.has_bias:
            self.bias = None

# TODO: merge two marlin class

class VLinearMarlin(KLinearBase):
    marlin_q_w: torch.Tensor
    marlin_s: torch.Tensor
    g_idx: torch.Tensor
    sort_indices: torch.Tensor
    has_bias: bool
    def __init__(
        self,
        key: str,
        gguf_loader: GGUFLoader,
        config: PretrainedConfig,
        orig_module: nn.Module = None,
        device: str = "cuda",
        num_bits: int = 4, # 4-bit/8-bit is supported
        group_size: int = 64, # -1, 32, 64, 128
        act_order: bool = False,
        is_k_full=True,
        **kwargs,
    ):
        assert device.lower() != "cpu", "Marlin quantized linear only supports GPU device"
        super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
        self.num_bits = num_bits
        self.group_size = group_size
        self.act_order = act_order
        self.is_k_full = is_k_full
        self.padding = False
        self.orin_in_features = self.in_features
        self.orin_out_features = self.out_features
        if self.in_features%GPTQ_MARLIN_MIN_THREAD_K!=0 or self.out_features%GPTQ_MARLIN_MIN_THREAD_K!=0:
            #print(f"warning!, in_features={in_features} or out_features={out_features} is undivisible by GPTQ_MARLIN_MIN_THREAD_K={GPTQ_MARLIN_MIN_THREAD_K} and GPTQ_MARLIN_MIN_THREAD_N={GPTQ_MARLIN_MIN_THREAD_N}, padding")
            self.padding = True
            self.in_features = (self.in_features+GPTQ_MARLIN_MIN_THREAD_K-1)//GPTQ_MARLIN_MIN_THREAD_K*GPTQ_MARLIN_MIN_THREAD_K
            self.out_features = (self.out_features+GPTQ_MARLIN_MIN_THREAD_N-1)//GPTQ_MARLIN_MIN_THREAD_N*GPTQ_MARLIN_MIN_THREAD_N
            #print(f"After padding: in_features={in_features}, out_features={out_features}")

        self.k = self.in_features
        self.n = self.out_features

    def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None):
        if self.loaded: return
        if device is None: device = self.device
        assert device.lower() != "cpu", "Marlin quantized linear only supports GPU device"

        #if self.in_features * self.out_features:
        if w is None:
            w = self.load_weight(device=device)

        if isinstance(w, nn.Parameter):
            # pad weight
            weight = w.view(self.orin_out_features, self.orin_in_features).T
            self.has_bias = False
        elif isinstance(w, tuple):
            w = list(w)
            weight = w[0].view(self.orin_out_features, self.orin_in_features).T
            self.bias = w[1].view(self.orin_out_features)
            self.bias = w[1]
            self.has_bias = True
        else:
            raise ValueError("Invalid weight type")
        weight = weight.to(device)
        if self.has_bias:
            self.bias = self.bias.to(device)

        if self.padding:
            padded_weight = torch.zeros(self.in_features, self.out_features, device=self.device)
            padded_weight[:self.orin_in_features, :self.orin_out_features] = weight
            weight = padded_weight

        # Pack Marlin linear
        marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
            weight, self.num_bits, self.group_size, self.act_order
        )
        self.workspace = MarlinWorkspace(
            self.out_features, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL, self.device
        )
        self.weight = marlin_q_w
        self.marlin_q_w = marlin_q_w
        self.marlin_s = marlin_s
        self.g_idx = g_idx
        self.sort_indices = sort_indices
        self.k = weight.shape[0]
        self.n = weight.shape[1]
        # self.shape_buffer = torch.tensor([60], dtype=torch.int32, device=self.device)
        self.loaded = True

    def forward(self, x: torch.Tensor, bsz_tensor: torch.Tensor = None) -> torch.Tensor:
        if bsz_tensor is None:
            bsz_tensor = torch.tensor([x.shape[0]], dtype=torch.int32, device=self.device)

        # Only support input x as BF16 and FP16
        x = x.to(self.device)
        orig_shape = list(x.shape)
        orig_dtype = x.dtype
        x = x.reshape(-1, orig_shape[-1])
        marlin_s = self.marlin_s.to(x.dtype)
        sms = -1

        x = vLLMMarlin.gptq_marlin_gemm(
            x,
            self.marlin_q_w,
            marlin_s,
            self.g_idx,
            self.sort_indices,
            self.workspace.scratch,
            self.num_bits,
            bsz_tensor,
            # torch.tensor([x.shape[0]], dtype=torch.int32, device=self.device),
            x.shape[0],
            self.n,
            x.shape[-1],
            sms,
            self.is_k_full,
        )
        # x = KTransformersOps.gptq_marlin_gemm(
        #     x,
        #     self.marlin_q_w,
        #     marlin_s,
        #     self.g_idx,
        #     self.sort_indices,
        #     self.workspace.scratch,
        #     self.num_bits,
        #     x.shape[0],
        #     self.n,
        #     x.shape[-1],
        #     self.is_k_full,
        # )
        if self.has_bias:
            x = x + self.bias
        orig_shape[-1] = self.n
        return x.reshape(orig_shape).to(orig_dtype)

    def unload(self):

        if self.has_bias:
            self.bias = None
        self.marlin_q_w = None
        self.marlin_s = None
        self.g_idx = None
        self.sort_indices = None
        self.workspace = None

class KLinearMarlin(KLinearBase):
    marlin_q_w: torch.Tensor
    marlin_s: torch.Tensor
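Note: VLinearMarlin pads in_features/out_features up to the Marlin thread-tile multiples before quantizing, and the padded rows/columns stay zero. The rounding is a plain ceil-division; a tiny sketch, assuming the usual vLLM constants (GPTQ_MARLIN_MIN_THREAD_K = 128, GPTQ_MARLIN_MIN_THREAD_N = 64) as an illustration:

def round_up(x: int, multiple: int) -> int:
    # same rounding VLinearMarlin applies to in_features / out_features
    return (x + multiple - 1) // multiple * multiple

# e.g. an in_features of 5168 would be padded to round_up(5168, 128) == 5248
print(round_up(5168, 128), round_up(1000, 64))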
@@ -483,7 +630,7 @@ class KLinearMarlin(KLinearBase):
        self.n = weight.shape[1]
        self.loaded = True

    def forward(self, x: torch.Tensor) -> torch.Tensor:
    def forward(self, x: torch.Tensor, bsz_tensor: torch.Tensor=None, **kwargs) -> torch.Tensor:
        # Only support input x as BF16 and FP16
        x = x.to(self.device)
        orig_shape = list(x.shape)
@@ -629,12 +776,13 @@ class KLinearCPUInfer(KLinearBase):
        if self.w is not None:
            self.w = None
        if self.has_bias:
            self.bias = None
        self.bias = None

LINEAR_MAP = {
    "KLinearMarlin": KLinearMarlin,
    "KLinearTorch": KLinearTorch,
    "KLinearCPUInfer": KLinearCPUInfer,
    "VLinearMarlin": VLinearMarlin,
    "KLinearFP8": KLinearFP8,
    "KLinearQ8": KLinearQ8,
}
@@ -668,13 +816,13 @@ class KTransformersLinear(BaseInjectedModule, KLinearBase):
        self.generate_linear = None
        self.mode = InferenceState.UNLOAD

    def forward(self, x):
    def forward(self, x, bsz_tensor=None):
        if self.mode == InferenceState.PREFILL:
            assert self.prefill_linear is not None, "cpu linear is not initialized"
            y = self.prefill_linear.forward(x)
            y = self.prefill_linear.forward(x, bsz_tensor)
        else:
            assert self.generate_linear is not None, "gpu linear is not initialized"
            y = self.generate_linear.forward(x)
            y = self.generate_linear.forward(x, bsz_tensor)
        return y

    def load(self, w: dict | nn.Parameter | tuple | None = None, mode: InferenceState = InferenceState.GENERATE):
@@ -717,3 +865,5 @@ class KTransformersLinear(BaseInjectedModule, KLinearBase):
            self.unload()
        else:
            raise ValueError("mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD")
ktransformers/operators/mlp.py (new file, 23 lines)
@@ -0,0 +1,23 @@

from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.util.custom_gguf import GGUFLoader
from transformers import PretrainedConfig
import torch.nn as nn
from ktransformers.models.modeling_deepseek_v3 import DeepseekV3MLP


class kDeepseekV3MLP(DeepseekV3MLP, BaseInjectedModule):
    def __init__(self,
                 key: str,
                 gguf_loader : GGUFLoader,
                 config: PretrainedConfig,
                 orig_module: nn.Module,
                 prefill_device: str = "cuda",
                 generate_device: str = "cuda",
                 **kwargs):
        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
        self.orig_module.__init__(orig_module.config,
                                  orig_module.hidden_size, orig_module.intermediate_size)
    def forward(self, x, bsz_tensor):
        down_proj = self.down_proj(self.act_fn(self.gate_proj(x, bsz_tensor)) * self.up_proj(x, bsz_tensor), bsz_tensor)
        return down_proj
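Note: kDeepseekV3MLP keeps the standard gated-MLP form, down(act(gate(x)) * up(x)), just with bsz_tensor threaded through each injected linear. A plain reference version without the injection plumbing (SiLU assumed as the activation, as in the DeepSeek configs):

import torch
import torch.nn as nn

class GatedMLP(nn.Module):
    # reference shape of the computation, without the bsz_tensor plumbing
    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.act_fn = nn.SiLU()

    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

print(GatedMLP(8, 32)(torch.randn(2, 8)).shape)  # torch.Size([2, 8])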