mirror of https://github.com/kvcache-ai/ktransformers.git (synced 2026-04-28 03:39:48 +00:00)

Merge pull request #1545 from kvcache-ai/develop-cht
update kt-kernel: support Expert Deferral mechanism

commit b09e99fd87
5 changed files with 162 additions and 51 deletions
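In brief, the Expert Deferral mechanism added below splits each token's routed experts into an immediate set (the highest-scoring ones, computed right away) and a deferred set (up to max_deferred_experts_per_token low-scoring experts, submitted to the CPU pool as a separate task). Synchronization now takes an allow_n_pending threshold, so the caller can collect a layer's immediate output while the deferred task is still running; the deferred contribution is written into the output slot of the following layer. The snippet below is a standalone, per-token toy of the split only; names and sizes are illustrative, and the real select_deferred_experts added further down additionally keeps an expert immediate batch-wide once any token protects it.

import torch

def split_experts(expert_ids: torch.Tensor, expert_scores: torch.Tensor, protected_k: int):
    """Toy per-token split: keep the protected_k highest-scoring experts of each token,
    mark the rest as deferred (-1 = not handled by this pass)."""
    protected_idx = torch.topk(expert_scores, k=protected_k, dim=-1).indices
    protected = torch.zeros_like(expert_ids)
    protected.scatter_(-1, protected_idx, 1)
    protected = protected.bool()
    immediate_ids = expert_ids.masked_fill(~protected, -1)
    deferred_ids = expert_ids.masked_fill(protected, -1)
    return immediate_ids, deferred_ids

# 2 tokens routed to their top-4 experts; defer the single lowest-scoring expert per token
ids = torch.tensor([[3, 7, 1, 5], [2, 9, 4, 0]])
scores = torch.tensor([[0.5, 0.2, 0.2, 0.1], [0.4, 0.3, 0.2, 0.1]])
immediate, deferred = split_experts(ids, scores, protected_k=3)
print(immediate)  # tensor([[ 3,  7,  1, -1], [ 2,  9,  4, -1]])
print(deferred)   # tensor([[-1, -1, -1,  5], [-1, -1, -1,  0]])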
@@ -94,22 +94,22 @@ class CPUInfer {

     struct SyncArgs {
         CPUInfer* cpuinfer;
-        size_t n;
+        size_t allow_n_pending;
     };

     static void sync_(void* sync_args) {
         SyncArgs* args = (SyncArgs*)sync_args;
-        args->cpuinfer->task_queue_->sync(args->n);
+        args->cpuinfer->task_queue_->sync(args->allow_n_pending);
     }

-    void sync(size_t n = 0) {
-        SyncArgs* args = new SyncArgs{this, n};
+    void sync(size_t allow_n_pending = 0) {
+        SyncArgs* args = new SyncArgs{this, allow_n_pending};
         sync_(args);
     }
 #ifndef KTRANSFORMERS_CPU_ONLY
-    void sync_with_cuda_stream(intptr_t user_cuda_stream, size_t n = 0) {
+    void sync_with_cuda_stream(intptr_t user_cuda_stream, size_t allow_n_pending = 0) {
 #if defined(KTRANSFORMERS_USE_CUDA)
-        SyncArgs* args = new SyncArgs{this, n};
+        SyncArgs* args = new SyncArgs{this, allow_n_pending};
         cudaLaunchHostFunc((cudaStream_t)user_cuda_stream, (cudaHostFn_t)&sync_, (void*)args);
 #endif
     }
@@ -42,7 +42,11 @@ void TaskQueue::enqueue(std::function<void()> task) {
     prev->next.store(node, std::memory_order_release);
 }

-void TaskQueue::sync(size_t n) { while (pending.load(std::memory_order_acquire) > n); }
+void TaskQueue::sync(size_t allow_n_pending) {
+    // Spin until the pending task count drops to the allowed threshold.
+    while (pending.load(std::memory_order_acquire) > allow_n_pending)
+        ;
+}

 void TaskQueue::worker() {
     Node* curr = head.load(std::memory_order_relaxed);
@@ -25,7 +25,7 @@ class TaskQueue {

     void enqueue(std::function<void()>);

-    void sync(size_t n);
+    void sync(size_t allow_n_pending);

   private:
     struct Node {
@@ -258,10 +258,11 @@ PYBIND11_MODULE(kt_kernel_ext, m) {
         .def(py::init<int>())
         .def(py::init<WorkerPoolConfig>())
         .def("submit", &CPUInfer::submit)
-        .def("sync", &CPUInfer::sync, py::arg("n") = 0)
+        .def("sync", &CPUInfer::sync, py::arg("allow_n_pending") = 0)
         .def_readwrite("backend_", &CPUInfer::backend_)
 #ifndef KTRANSFORMERS_CPU_ONLY
-        .def("sync_with_cuda_stream", &CPUInfer::sync_with_cuda_stream, py::arg("user_cuda_stream"), py::arg("n") = 0)
+        .def("sync_with_cuda_stream", &CPUInfer::sync_with_cuda_stream, py::arg("user_cuda_stream"),
+             py::arg("allow_n_pending") = 0)
         .def("submit_with_cuda_stream", &CPUInfer::submit_with_cuda_stream)
 #endif
         ;
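Taken together, the kernel-side hunks above change the meaning of sync from "drain the task queue" to "return once at most allow_n_pending tasks are still in flight", and expose the new keyword through pybind (e.g. cpu_infer.sync(allow_n_pending=1)). A pure-Python analogue of that contract, with a made-up PendingCounter standing in for TaskQueue's atomic pending counter:

import threading
import time

class PendingCounter:
    """Illustrative stand-in for TaskQueue's pending-task counter."""

    def __init__(self) -> None:
        self.pending = 0
        self.lock = threading.Lock()

    def enqueue(self, fn) -> None:
        with self.lock:
            self.pending += 1

        def run():
            fn()
            with self.lock:
                self.pending -= 1

        threading.Thread(target=run).start()

    def sync(self, allow_n_pending: int = 0) -> None:
        # Mirrors TaskQueue::sync: spin until at most allow_n_pending tasks remain pending.
        while True:
            with self.lock:
                if self.pending <= allow_n_pending:
                    return

q = PendingCounter()
q.enqueue(lambda: time.sleep(0.01))  # e.g. the immediate-experts task
q.enqueue(lambda: time.sleep(0.20))  # e.g. the deferred-experts task
q.sync(allow_n_pending=1)            # returns while one task may still be running
q.sync()                             # equivalent to sync(0): waits for everything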
@@ -12,7 +12,7 @@ implementations, handling weight loading, buffer management, and forward inference
 from __future__ import annotations

 import torch
-from typing import List, Dict
+from typing import Dict, List, Optional, Tuple
 from safetensors import safe_open
 import os
 import ctypes
@@ -146,30 +146,60 @@ class KExpertsCPUBuffer:
     capture_buffers: Dict = dict()
     temp_bs: int = 0
     temp_buffer: tuple = tuple()
+    buffer_depth: int = 2

     @classmethod
     def get_buffer(cls, hidden_states: torch.Tensor, num_experts_per_tok):
         hidden_size = hidden_states.shape[-1]
-        hidden_states = hidden_states.view(-1, hidden_size)
-        batch_size, hidden_size = hidden_states.shape
+        batch_size = hidden_states.shape[0]

-        if batch_size in KExpertsCPUBuffer.capture_buffers:
-            return KExpertsCPUBuffer.capture_buffers[batch_size]
-        if batch_size == KExpertsCPUBuffer.temp_bs:
-            return KExpertsCPUBuffer.temp_buffer
+        if batch_size in cls.capture_buffers:
+            return cls.capture_buffers[batch_size]
+        if batch_size == cls.temp_bs:
+            return cls.temp_buffer

-        input_tensor_cpu = torch.zeros((batch_size, hidden_size), device="cpu", pin_memory=True, dtype=torch.bfloat16)
-        expert_ids_cpu = torch.zeros((batch_size, num_experts_per_tok), device="cpu", dtype=torch.long, pin_memory=True)
-        weights_cpu = torch.zeros((batch_size, num_experts_per_tok), device="cpu", dtype=torch.float32, pin_memory=True)
-        output_cpu = torch.zeros((batch_size, hidden_size), device="cpu", pin_memory=True, dtype=torch.bfloat16)
-        bsz_tensor_cpu = torch.tensor((batch_size), device="cpu", dtype=torch.int32, pin_memory=True)
-        output_gpu = torch.zeros_like(hidden_states)
+        input_tensor_cpu = [
+            torch.zeros((batch_size, hidden_size), device="cpu", pin_memory=True, dtype=torch.bfloat16)
+            for _ in range(cls.buffer_depth)
+        ]
+        immediate_experts_ids_cpu = [
+            torch.zeros((batch_size, num_experts_per_tok), device="cpu", dtype=torch.long, pin_memory=True)
+            for _ in range(cls.buffer_depth)
+        ]
+        deferred_experts_ids_cpu = [
+            torch.full((batch_size, num_experts_per_tok), -1, device="cpu", dtype=torch.long, pin_memory=True)
+            for _ in range(cls.buffer_depth)
+        ]
+        weights_cpu = [
+            torch.zeros((batch_size, num_experts_per_tok), device="cpu", dtype=torch.float32, pin_memory=True)
+            for _ in range(cls.buffer_depth)
+        ]
+        output_cpu = [
+            torch.zeros((batch_size, hidden_size), device="cpu", pin_memory=True, dtype=torch.bfloat16)
+            for _ in range(cls.buffer_depth)
+        ]
+        bsz_tensor_cpu = [
+            torch.zeros((1,), device="cpu", dtype=torch.int32, pin_memory=True)
+            for _ in range(cls.buffer_depth)
+        ]
+        output_gpu = [
+            torch.zeros((batch_size, hidden_size), device=hidden_states.device, dtype=hidden_states.dtype)
+            for _ in range(cls.buffer_depth)
+        ]

-        cur_buffer = (input_tensor_cpu, expert_ids_cpu, weights_cpu, output_cpu, bsz_tensor_cpu, output_gpu)
-        if batch_size in KExpertsCPUBuffer.capture_bs:
-            KExpertsCPUBuffer.capture_buffers[batch_size] = cur_buffer
-        KExpertsCPUBuffer.temp_bs = batch_size
-        KExpertsCPUBuffer.temp_buffer = cur_buffer
+        cur_buffer = (
+            input_tensor_cpu,
+            immediate_experts_ids_cpu,
+            deferred_experts_ids_cpu,
+            weights_cpu,
+            output_cpu,
+            bsz_tensor_cpu,
+            output_gpu,
+        )
+        if batch_size in cls.capture_bs:
+            cls.capture_buffers[batch_size] = cur_buffer
+        cls.temp_bs = batch_size
+        cls.temp_buffer = cur_buffer
         return cur_buffer

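get_buffer now allocates buffer_depth (= 2) copies of every pinned host buffer plus the GPU-side output, and returns lists indexed by slot instead of single tensors. The wrapper code later in this diff picks slots with layer_idx % buffer_depth, so consecutive layers alternate slots and a layer's deferred task can still be writing into the slot that the next layer owns. A tiny standalone sketch of that slot arithmetic (illustrative only, not from the patch):

buffer_depth = 2  # KExpertsCPUBuffer.buffer_depth

def slots(layer_idx: int):
    current_slot = layer_idx % buffer_depth        # this layer's immediate output lands here
    next_slot = (current_slot + 1) % buffer_depth  # this layer's deferred output lands here
    return current_slot, next_slot

for layer_idx in range(4):
    cur, nxt = slots(layer_idx)
    cur_of_next_layer, _ = slots(layer_idx + 1)
    assert nxt == cur_of_next_layer  # layer i defers into the slot layer i + 1 will use
    print(layer_idx, cur, nxt)       # 0 0 1 / 1 1 0 / 2 0 1 / 3 1 0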
@@ -181,6 +211,7 @@ class AMXMoEWrapper:

     _cpu_infer_instance = None
     _safetensor_loader_instance = None
+    _layer_has_pending_deferred: Dict[int, bool] = {}

     def __init__(
         self,
@@ -195,6 +226,7 @@ class AMXMoEWrapper:
         amx_weight_path: str,
         chunked_prefill_size: int,
         cpu_save: bool = False,
+        max_deferred_experts_per_token: Optional[int] = None,
         amx_method: str = "AMXINT4",
     ):
         """
@@ -212,6 +244,7 @@ class AMXMoEWrapper:
             amx_weight_path: Path to AMX weights
             chunked_prefill_size: Maximum prefill chunk size
             cpu_save: Whether to save weights to CPU memory
+            max_deferred_experts_per_token: Number of experts per token to defer on this layer. Defaults to 0 (no defer).
             amx_method: AMX quantization method ("AMXINT4" or "AMXINT8")
         """

@@ -224,6 +257,9 @@ class AMXMoEWrapper:
         self.amx_weight_path = amx_weight_path
         self.chunked_prefill_size = chunked_prefill_size
         self.cpu_save = cpu_save
+        self.max_deferred_experts_per_token = int(max_deferred_experts_per_token) if max_deferred_experts_per_token is not None else 0
+
+        AMXMoEWrapper._layer_has_pending_deferred[self.layer_idx] = False
         self.amx_method = amx_method

         # Initialize CPU inference engine (singleton)
@@ -462,6 +498,36 @@ class AMXMoEWrapper:
         del self.up_scales
         del self.down_scales

+    def select_deferred_experts(
+        self,
+        expert_ids: torch.Tensor,
+        expert_scores: torch.Tensor,
+        protected_k: int,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        batch, topk = expert_ids.shape
+        device = expert_ids.device
+
+        protected_k = max(0, min(int(protected_k), topk))
+        if protected_k == 0:
+            deferred_ids = expert_ids.clone()
+            immediate_ids = torch.full_like(expert_ids, -1)
+            return immediate_ids, deferred_ids
+
+        topk_result = torch.topk(expert_scores, k=protected_k, dim=-1, largest=True, sorted=False)
+        protected_indices = topk_result.indices
+        protected_ids = torch.gather(expert_ids, -1, protected_indices)
+
+        protected_flag = torch.zeros((self.num_experts,), dtype=torch.int32, device=device)
+        protected_flag.scatter_(0, protected_ids.reshape(-1), 1)
+
+        protected_mask_flat = torch.gather(protected_flag, 0, expert_ids.reshape(-1)).ne(0)
+        protected_mask = protected_mask_flat.view(batch, topk)
+
+        immediate_ids = expert_ids.clone().masked_fill(~protected_mask, -1)
+        deferred_ids = expert_ids.clone().masked_fill(protected_mask, -1)
+
+        return immediate_ids, deferred_ids
+
     def submit_forward(
         self,
         hidden_states: torch.Tensor,
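A worked toy run of the new select_deferred_experts helper. The function below re-implements the same steps outside the wrapper so it can be executed standalone (num_experts and the tensors are made up). Unlike the per-token toy at the top of this page, protection here is batch-wide: an expert kept by any token stays immediate for every token that routed to it.

import torch

def select_deferred(expert_ids, expert_scores, protected_k, num_experts):
    # Same steps as AMXMoEWrapper.select_deferred_experts, minus the wrapper state.
    batch, topk = expert_ids.shape
    protected_k = max(0, min(int(protected_k), topk))
    if protected_k == 0:
        return torch.full_like(expert_ids, -1), expert_ids.clone()
    protected_idx = torch.topk(expert_scores, k=protected_k, dim=-1, sorted=False).indices
    protected_ids = torch.gather(expert_ids, -1, protected_idx)
    flag = torch.zeros((num_experts,), dtype=torch.int32)
    flag.scatter_(0, protected_ids.reshape(-1), 1)
    mask = torch.gather(flag, 0, expert_ids.reshape(-1)).ne(0).view(batch, topk)
    return expert_ids.masked_fill(~mask, -1), expert_ids.masked_fill(mask, -1)

ids = torch.tensor([[3, 7, 1], [7, 2, 5]])                 # top-3 routing over 8 experts
scores = torch.tensor([[0.6, 0.3, 0.1], [0.1, 0.7, 0.2]])
immediate, deferred = select_deferred(ids, scores, protected_k=2, num_experts=8)
print(immediate)  # token 0 keeps {3, 7}; token 1 keeps {2, 5} and also 7 (protected by token 0)
print(deferred)   # only expert 1 of token 0 is deferred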
@@ -478,36 +544,72 @@ class AMXMoEWrapper:
             topk_weights: Top-k expert weights [batch_size, num_experts_per_tok]
             cuda_stream: CUDA stream for synchronization
         """
-        # Get CPU buffers
+        flat_hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
+        batch_size = flat_hidden_states.shape[0]
+
         (
             input_tensor_cpu,
-            expert_ids_cpu,
+            immediate_experts_ids_cpu,
+            deferred_experts_ids_cpu,
             weights_cpu,
             output_cpu,
             bsz_tensor_cpu,
-            output_gpu,
-        ) = KExpertsCPUBuffer.get_buffer(hidden_states, self.num_experts_per_tok)
+            _output_gpu,
+        ) = KExpertsCPUBuffer.get_buffer(flat_hidden_states, self.num_experts_per_tok)

-        # Copy data to CPU
-        topk_ids = topk_ids.to(torch.long)
-        input_tensor_cpu.copy_(hidden_states, non_blocking=True)
-        expert_ids_cpu.copy_(topk_ids, non_blocking=True)
-        weights_cpu.copy_(topk_weights, non_blocking=True)
+        current_slot = self.layer_idx % KExpertsCPUBuffer.buffer_depth
+        next_slot = (current_slot + 1) % KExpertsCPUBuffer.buffer_depth

-        # Submit task
+        bsz_slot_tensor = bsz_tensor_cpu[current_slot]
+        bsz_slot_tensor.fill_(batch_size)
+        deferred_experts_ids_cpu[current_slot].fill_(-1)
+
+        topk_ids_long = topk_ids.to(torch.long)
+        immediate_ids: torch.Tensor
+        deferred_ids: Optional[torch.Tensor]
+        if self.max_deferred_experts_per_token > 0:
+            protected_k = self.num_experts_per_tok - self.max_deferred_experts_per_token
+
+            immediate_ids, deferred_ids = self.select_deferred_experts(topk_ids_long, topk_weights, protected_k)
+        else:
+            immediate_ids = topk_ids_long
+            deferred_ids = None
+
+        input_tensor_cpu[current_slot].copy_(flat_hidden_states, non_blocking=True)
+        weights_cpu[current_slot].copy_(topk_weights, non_blocking=True)
+        immediate_experts_ids_cpu[current_slot].copy_(immediate_ids, non_blocking=True)
+
+        incremental = AMXMoEWrapper._layer_has_pending_deferred.get(self.layer_idx - 1, False)
         self.cpu_infer.submit_with_cuda_stream(
             cuda_stream,
             self.moe.forward_task(
-                bsz_tensor_cpu.data_ptr(),
-                expert_ids_cpu.size(-1),
-                expert_ids_cpu.data_ptr(),
-                weights_cpu.data_ptr(),
-                input_tensor_cpu.data_ptr(),
-                output_cpu.data_ptr(),
-                False,
+                bsz_slot_tensor.data_ptr(),
+                immediate_experts_ids_cpu[current_slot].size(-1),
+                immediate_experts_ids_cpu[current_slot].data_ptr(),
+                weights_cpu[current_slot].data_ptr(),
+                input_tensor_cpu[current_slot].data_ptr(),
+                output_cpu[current_slot].data_ptr(),
+                incremental,
             ),
         )

+        AMXMoEWrapper._layer_has_pending_deferred[self.layer_idx] = False
+        if deferred_ids is not None:
+            deferred_experts_ids_cpu[current_slot].copy_(deferred_ids, non_blocking=True)
+            self.cpu_infer.submit_with_cuda_stream(
+                cuda_stream,
+                self.moe.forward_task(
+                    bsz_slot_tensor.data_ptr(),
+                    deferred_experts_ids_cpu[current_slot].size(-1),
+                    deferred_experts_ids_cpu[current_slot].data_ptr(),
+                    weights_cpu[current_slot].data_ptr(),
+                    input_tensor_cpu[current_slot].data_ptr(),
+                    output_cpu[next_slot].data_ptr(),
+                    False,
+                ),
+            )
+            AMXMoEWrapper._layer_has_pending_deferred[self.layer_idx] = True
+
     def sync_forward(self, hidden_states: torch.Tensor, cuda_stream) -> torch.Tensor:
         """
         Synchronize and retrieve forward inference results.
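submit_forward now issues up to two forward_task submissions per layer: the immediate experts write into output_cpu[current_slot], and, when anything was deferred, a second task runs the deferred experts and writes into output_cpu[next_slot], the slot the following layer owns. The incremental argument of the first task is True exactly when the previous layer left a deferred task behind, which suggests the kernel then accumulates into the slot rather than overwriting it; that reading is inferred from the variable name only, since the kernel side is not part of this diff. Under that assumption, the bookkeeping reduces to the toy simulation below (all names and values made up):

import torch

buffer_depth, hidden = 2, 4
output_cpu = [torch.zeros(hidden) for _ in range(buffer_depth)]

def run_experts(contribution, out, incremental):
    # Assumed semantics of the incremental flag: accumulate instead of overwrite.
    if incremental:
        out += contribution
    else:
        out.copy_(contribution)

prev_layer_deferred = False
for layer_idx in range(3):
    cur, nxt = layer_idx % buffer_depth, (layer_idx + 1) % buffer_depth
    immediate_part = torch.full((hidden,), float(layer_idx))  # fake immediate-expert output
    deferred_part = torch.full((hidden,), 0.5)                # fake deferred-expert output
    run_experts(immediate_part, output_cpu[cur], incremental=prev_layer_deferred)
    run_experts(deferred_part, output_cpu[nxt], incremental=False)  # lands in the next layer's slot
    prev_layer_deferred = True
    print(layer_idx, output_cpu[cur].tolist())
# layer 0 -> 0.0 everywhere; layers 1 and 2 -> immediate part plus the 0.5 deferred carry-over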
@@ -519,18 +621,22 @@ class AMXMoEWrapper:
         Returns:
             output_gpu: Output tensor on GPU
         """
+        flat_hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
         (
             input_tensor_cpu,
-            expert_ids_cpu,
+            immediate_experts_ids_cpu,
+            _deferred_experts_ids_cpu,
             weights_cpu,
             output_cpu,
-            bsz_tensor_cpu,
+            _bsz_tensor_cpu,
             output_gpu,
-        ) = KExpertsCPUBuffer.get_buffer(hidden_states, self.num_experts_per_tok)
+        ) = KExpertsCPUBuffer.get_buffer(flat_hidden_states, self.num_experts_per_tok)

-        self.cpu_infer.sync_with_cuda_stream(cuda_stream)
-        output_gpu.copy_(output_cpu, non_blocking=True)
-        return output_gpu
+        current_slot = self.layer_idx % KExpertsCPUBuffer.buffer_depth
+        allow_pending = 1 if AMXMoEWrapper._layer_has_pending_deferred.get(self.layer_idx, False) else 0
+        self.cpu_infer.sync_with_cuda_stream(cuda_stream, allow_pending)
+        output_gpu[current_slot].copy_(output_cpu[current_slot], non_blocking=True)
+        return output_gpu[current_slot]

     def forward(
         self,