mirror of https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-10 06:14:58 +00:00

add balance-serve, support concurrency

This commit is contained in:
parent 8d0292aa44
commit 25cee5810e

196 changed files with 22077 additions and 565 deletions
@@ -37,6 +37,10 @@ import time

from ktransformers.operators.cpuinfer import CPUInfer


def deduplicate_and_sort(lst):
    return sorted(set(lst))

# cuda_graphs = [Config().chunk_size]
cuda_graphs = deduplicate_and_sort([1, 2, 3, Config().max_batch_size, 64, Config().chunk_size])

# class Base(BaseInjectedModule, ABC):
class KExpertsBase(ABC):
    def __init__(self, key: str, gguf_loader: GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, device: str = "cuda", **kwargs):
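The `cuda_graphs` list above fixes the set of batch-size buckets for which CUDA graphs and their staging buffers get pre-allocated. Below is a minimal sketch of how a runtime could map an incoming batch onto the smallest bucket that fits; the `pick_graph_idx` helper is hypothetical and not part of this commit:

def pick_graph_idx(batch_size: int, buckets: list) -> int:
    # buckets is sorted ascending, e.g. deduplicate_and_sort([1, 2, 3, 32, 64, 256])
    for i, capacity in enumerate(buckets):
        if batch_size <= capacity:
            return i  # smallest captured graph that can hold this batch
    return -1  # nothing fits: caller falls back to the eager path

# pick_graph_idx(5, [1, 2, 3, 32, 64, 256]) == 3, i.e. the 32-token bucket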
@@ -112,6 +116,7 @@ class KExpertsBase(ABC):
                tensors[k] = self.gguf_loader.load_gguf_tensor(key + k, device=device)
        return tensors


class KExpertsCPU(KExpertsBase):
    input_tensor_cpu: Tensor = None
    expert_ids_cpu: Tensor = None
@@ -119,8 +124,8 @@ class KExpertsCPU(KExpertsBase):
    output_cpu: Tensor = None
    output_gpu_map: dict = {}  # Manage output tensor buffers on different GPUs
    # stream_map: dict = {}  # Manage cuda stream on different gpu
    # gguf_loader: GGUFLoader = None
    CPU_INFER = None
    # @TODO add yaml
    CPU_INFER = CPUInfer(Config().cpu_infer)

    def __init__(
        self,
        key: str,
@@ -133,11 +138,6 @@ class KExpertsCPU(KExpertsBase):
        **kwargs
    ):
        super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
        if KExpertsCPU.CPU_INFER is None:
            KExpertsCPU.CPU_INFER = CPUInfer(Config().cpu_infer)
        # if KExpertsCPU.gguf_loader is None:
        #     KExpertsCPU.gguf_loader = GGUFLoader("/mnt/data/model/DeepseekV3-q4km-gguf")
        self.gguf_loader = gguf_loader
        assert device.lower() == "cpu", "KExpertsCPU can only be loaded on CPU"
        self.n_routed_experts = n_routed_experts
        self.out_device = out_device
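With `CPU_INFER = None` in the class body and the guard in `__init__`, the `CPUInfer` thread pool is now created lazily, once, when the first `KExpertsCPU` instance is built rather than at import time. A minimal sketch of the same pattern, with a standard-library executor standing in for `CPUInfer`:

from concurrent.futures import ThreadPoolExecutor

class Worker:
    POOL = None  # shared by all instances, built on first construction

    def __init__(self, n_threads: int):
        if Worker.POOL is None:  # runs once; later instances reuse the pool
            Worker.POOL = ThreadPoolExecutor(max_workers=n_threads)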
@@ -161,7 +161,7 @@ class KExpertsCPU(KExpertsBase):
        down_ptr = ctypes.addressof(
            ctypes.cast(self.down.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents
        )
        # print(self.gate_type, self.up_type, self.down_type)
        # print(self.gate_qtype, self.up_qtype, self.down_qtype)
        n_routed_experts = self.n_routed_experts
        # n_routed_experts = len(self.orig_module)
        moe_config = MOEConfig(
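The `ctypes.addressof(ctypes.cast(...).contents)` dance above just recovers the raw address of a weight buffer so it can be handed to the native MoE kernel. A minimal sketch of the underlying idea using a torch tensor's `data_ptr()` (illustrative only; the diff itself goes through numpy's `.ctypes.data`):

import ctypes
import torch

t = torch.arange(4, dtype=torch.float32)
# reinterpret the tensor's storage as a C float* for native code; no copy is made
ptr = ctypes.cast(t.data_ptr(), ctypes.POINTER(ctypes.c_float))
print(ptr[0], ptr[3])  # 0.0 3.0 -- same memory the tensor owns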
@@ -188,43 +188,83 @@ class KExpertsCPU(KExpertsBase):
            self.cpu_infer.submit(self.moe.warm_up())
            self.cpu_infer.sync()
        if self.out_device not in KExpertsCPU.output_gpu_map:
            KExpertsCPU.output_gpu_map[self.out_device] = torch.zeros((self.config.hidden_size), device=self.out_device)
            if isinstance(cuda_graphs, list):
                # one GPU output buffer per captured graph size
                KExpertsCPU.output_gpu_map[self.out_device] = [torch.zeros((cuda_graphs[i], self.config.hidden_size), device=self.out_device) for i in range(len(cuda_graphs))]
            else:
                KExpertsCPU.output_gpu_map[self.out_device] = torch.zeros((cuda_graphs, self.config.hidden_size), device=self.out_device)
        if KExpertsCPU.input_tensor_cpu is None:
            KExpertsCPU.input_tensor_cpu = torch.zeros((self.config.hidden_size), device="cpu", pin_memory=True)
            KExpertsCPU.expert_ids_cpu = torch.zeros((num_experts_per_tok), device="cpu", dtype=torch.long, pin_memory=True)
            KExpertsCPU.weights_cpu = torch.zeros((num_experts_per_tok), device="cpu", dtype=torch.float32, pin_memory=True)
            KExpertsCPU.output_cpu = torch.zeros((self.config.hidden_size), device="cpu", pin_memory=True, dtype=torch.bfloat16)
            if isinstance(cuda_graphs, list):
                # pinned (page-locked) host staging buffers, one set per captured graph size
                KExpertsCPU.input_tensor_cpu = [torch.zeros((cuda_graphs[i], self.config.hidden_size), device="cpu", pin_memory=True) for i in range(len(cuda_graphs))]
                KExpertsCPU.expert_ids_cpu = [torch.zeros((cuda_graphs[i], num_experts_per_tok), device="cpu", dtype=torch.long, pin_memory=True) for i in range(len(cuda_graphs))]
                KExpertsCPU.weights_cpu = [torch.zeros((cuda_graphs[i], num_experts_per_tok), device="cpu", dtype=torch.float32, pin_memory=True) for i in range(len(cuda_graphs))]
                KExpertsCPU.output_cpu = [torch.zeros((cuda_graphs[i], self.config.hidden_size), device="cpu", pin_memory=True, dtype=torch.bfloat16) for i in range(len(cuda_graphs))]
                KExpertsCPU.bsz_tensor_cpu = [torch.zeros((1), device="cpu", dtype=torch.int32, pin_memory=True) for i in range(len(cuda_graphs))]
            else:
                KExpertsCPU.input_tensor_cpu = torch.zeros((cuda_graphs, self.config.hidden_size), device="cpu", pin_memory=True)
                KExpertsCPU.expert_ids_cpu = torch.zeros((cuda_graphs, num_experts_per_tok), device="cpu", dtype=torch.long, pin_memory=True)
                KExpertsCPU.weights_cpu = torch.zeros((cuda_graphs, num_experts_per_tok), device="cpu", dtype=torch.float32, pin_memory=True)
                KExpertsCPU.output_cpu = torch.zeros((cuda_graphs, self.config.hidden_size), device="cpu", pin_memory=True, dtype=torch.bfloat16)
                KExpertsCPU.bsz_tensor_cpu = torch.zeros((1), device="cpu", dtype=torch.int32, pin_memory=True)
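All of the staging tensors above are allocated with `pin_memory=True`. Pinned (page-locked) host memory is what lets the later `copy_(..., non_blocking=True)` calls overlap with GPU work instead of serializing on a pageable-memory transfer. A minimal sketch of the mechanism (assumes a CUDA device is available):

import torch

host_buf = torch.zeros(1024, pin_memory=True)   # page-locked host staging buffer
dev_buf = torch.zeros(1024, device="cuda")
dev_buf.copy_(host_buf, non_blocking=True)      # async H2D copy on the current stream
torch.cuda.current_stream().synchronize()       # result is safe to read after this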

    def submit_for_one_decode(self, input_tensor, expert_ids, weights):
        KExpertsCPU.input_tensor_cpu.copy_(input_tensor, non_blocking=True)
        KExpertsCPU.expert_ids_cpu.copy_(expert_ids, non_blocking=True)
        KExpertsCPU.weights_cpu.copy_(weights, non_blocking=True)
        self.cpu_infer.submit_with_cuda_stream(torch.cuda.current_stream(self.out_device).cuda_stream, self.moe.forward(1, expert_ids.size(0), KExpertsCPU.expert_ids_cpu.data_ptr(), KExpertsCPU.weights_cpu.data_ptr(), KExpertsCPU.input_tensor_cpu.data_ptr(), KExpertsCPU.output_cpu.data_ptr()))

    def sync_for_one_decode(self):
        self.cpu_infer.sync_with_cuda_stream(torch.cuda.current_stream(self.out_device).cuda_stream)
        KExpertsCPU.output_gpu_map[self.out_device].copy_(KExpertsCPU.output_cpu, non_blocking=True)
        return KExpertsCPU.output_gpu_map[self.out_device]

    def forward(self, input_tensor, expert_ids, weights):
        # generate, capture and run cuda graph
        # print(expert_ids)
        if input_tensor.size(0) == 1 and torch.cuda.is_current_stream_capturing():
            # TODO: this branch is unreachable, but the shapes of input_tensor ([1, hidden_size]) and input_tensor_cpu ([hidden_size]) are not compatible
            # print("capturing experts")

    def submit_for_one_decode(self, input_tensor, expert_ids, weights, bsz_tensor=None, cuda_graph_idx=0):
        if bsz_tensor is None:
            bsz_tensor = torch.ones(1, device=input_tensor.device, dtype=torch.int32)
        if cuda_graph_idx != -1:
            KExpertsCPU.input_tensor_cpu[cuda_graph_idx].copy_(input_tensor, non_blocking=True)
            KExpertsCPU.expert_ids_cpu[cuda_graph_idx].copy_(expert_ids, non_blocking=True)
            KExpertsCPU.weights_cpu[cuda_graph_idx].copy_(weights, non_blocking=True)
            KExpertsCPU.bsz_tensor_cpu[cuda_graph_idx].copy_(bsz_tensor, non_blocking=True)
            self.cpu_infer.submit_with_cuda_stream(torch.cuda.current_stream(self.out_device).cuda_stream, self.moe.forward(1, expert_ids.size(-1), KExpertsCPU.expert_ids_cpu[cuda_graph_idx].data_ptr(), KExpertsCPU.weights_cpu[cuda_graph_idx].data_ptr(), KExpertsCPU.input_tensor_cpu[cuda_graph_idx].data_ptr(), KExpertsCPU.output_cpu[cuda_graph_idx].data_ptr(), KExpertsCPU.bsz_tensor_cpu[cuda_graph_idx].data_ptr()))
        else:
            KExpertsCPU.input_tensor_cpu.copy_(input_tensor, non_blocking=True)
            KExpertsCPU.expert_ids_cpu.copy_(expert_ids, non_blocking=True)
            KExpertsCPU.weights_cpu.copy_(weights, non_blocking=True)
            self.cpu_infer.submit_with_cuda_stream(torch.cuda.current_stream().cuda_stream, self.moe.forward(1, expert_ids.size(1), KExpertsCPU.expert_ids_cpu.data_ptr(), KExpertsCPU.weights_cpu.data_ptr(), KExpertsCPU.input_tensor_cpu.data_ptr(), KExpertsCPU.output_cpu.data_ptr()))
            self.cpu_infer.sync_with_cuda_stream(torch.cuda.current_stream().cuda_stream)
            KExpertsCPU.bsz_tensor_cpu.copy_(bsz_tensor, non_blocking=True)
            self.cpu_infer.submit_with_cuda_stream(torch.cuda.current_stream(self.out_device).cuda_stream, self.moe.forward(1, expert_ids.size(-1), KExpertsCPU.expert_ids_cpu.data_ptr(), KExpertsCPU.weights_cpu.data_ptr(), KExpertsCPU.input_tensor_cpu.data_ptr(), KExpertsCPU.output_cpu.data_ptr(), KExpertsCPU.bsz_tensor_cpu.data_ptr()))

    def sync_for_one_decode(self, cuda_graph_idx=0):
        if cuda_graph_idx != -1:
            self.cpu_infer.sync_with_cuda_stream(torch.cuda.current_stream(self.out_device).cuda_stream)
            KExpertsCPU.output_gpu_map[self.out_device][cuda_graph_idx].copy_(KExpertsCPU.output_cpu[cuda_graph_idx], non_blocking=True)
            return KExpertsCPU.output_gpu_map[self.out_device][cuda_graph_idx]
        else:
            self.cpu_infer.sync_with_cuda_stream(torch.cuda.current_stream(self.out_device).cuda_stream)
            KExpertsCPU.output_gpu_map[self.out_device].copy_(KExpertsCPU.output_cpu, non_blocking=True)
            return KExpertsCPU.output_gpu_map[self.out_device]

    def forward(self, input_tensor, expert_ids, weights, bsz_tensor=None, cuda_graph_idx=0):
        # generate, capture and run cuda graph
        # print(expert_ids)
        if bsz_tensor is None:
            bsz_tensor = torch.tensor([input_tensor.size(0)], device=input_tensor.device, dtype=torch.int32)
        if torch.cuda.is_current_stream_capturing():
            if cuda_graph_idx != -1:
                KExpertsCPU.input_tensor_cpu[cuda_graph_idx].copy_(input_tensor, non_blocking=True)
                KExpertsCPU.expert_ids_cpu[cuda_graph_idx].copy_(expert_ids, non_blocking=True)
                KExpertsCPU.weights_cpu[cuda_graph_idx].copy_(weights, non_blocking=True)
                KExpertsCPU.bsz_tensor_cpu[cuda_graph_idx].copy_(bsz_tensor, non_blocking=True)
                self.cpu_infer.submit_with_cuda_stream(torch.cuda.current_stream().cuda_stream, self.moe.forward(expert_ids.size(0), expert_ids.size(-1), KExpertsCPU.expert_ids_cpu[cuda_graph_idx].data_ptr(), KExpertsCPU.weights_cpu[cuda_graph_idx].data_ptr(), KExpertsCPU.input_tensor_cpu[cuda_graph_idx].data_ptr(), KExpertsCPU.output_cpu[cuda_graph_idx].data_ptr(), KExpertsCPU.bsz_tensor_cpu[cuda_graph_idx].data_ptr()))
                self.cpu_infer.sync_with_cuda_stream(torch.cuda.current_stream().cuda_stream)
                KExpertsCPU.output_gpu_map[self.out_device][cuda_graph_idx].copy_(KExpertsCPU.output_cpu[cuda_graph_idx], non_blocking=True)
                return KExpertsCPU.output_gpu_map[self.out_device][cuda_graph_idx]

            else:
                KExpertsCPU.input_tensor_cpu.copy_(input_tensor, non_blocking=True)
                KExpertsCPU.expert_ids_cpu.copy_(expert_ids, non_blocking=True)
                KExpertsCPU.weights_cpu.copy_(weights, non_blocking=True)
                KExpertsCPU.bsz_tensor_cpu.copy_(bsz_tensor, non_blocking=True)
                self.cpu_infer.submit_with_cuda_stream(torch.cuda.current_stream().cuda_stream, self.moe.forward(expert_ids.size(0), expert_ids.size(-1), KExpertsCPU.expert_ids_cpu.data_ptr(), KExpertsCPU.weights_cpu.data_ptr(), KExpertsCPU.input_tensor_cpu.data_ptr(), KExpertsCPU.output_cpu.data_ptr(), KExpertsCPU.bsz_tensor_cpu.data_ptr()))
                self.cpu_infer.sync_with_cuda_stream(torch.cuda.current_stream().cuda_stream)
                KExpertsCPU.output_gpu_map[self.out_device].copy_(KExpertsCPU.output_cpu, non_blocking=True)
                return KExpertsCPU.output_gpu_map[self.out_device]
        else:
            input_tensor = input_tensor.contiguous().cpu()
            expert_ids = expert_ids.contiguous().cpu()
            weights = weights.contiguous().to(torch.float32).cpu()
            bsz_tensor = bsz_tensor.contiguous().cpu()
            output = torch.empty_like(input_tensor).contiguous()
            self.cpu_infer.submit(self.moe.forward(expert_ids.size(0), expert_ids.size(1), expert_ids.data_ptr(), weights.data_ptr(), input_tensor.data_ptr(), output.data_ptr()))
            self.cpu_infer.submit(self.moe.forward(expert_ids.size(0), expert_ids.size(1), expert_ids.data_ptr(), weights.data_ptr(), input_tensor.data_ptr(), output.data_ptr(), bsz_tensor.data_ptr()))
            self.cpu_infer.sync()
            return output.to(device=object.__getattribute__(self, "out_device"))

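The split `submit_for_one_decode` / `sync_for_one_decode` pair exists so a caller can overlap the CPU expert computation with GPU work on the same stream. A minimal usage sketch, mirroring how `KDeepseekV3MoEV2.forward` below drives it (method names are from this diff; the surrounding variables are illustrative):

# enqueue routed-expert work on the CPU backend, tied to the CUDA stream
experts.submit_for_one_decode(hidden_states, topk_idx, topk_weight, bsz_tensor, graph_idx)
y_shared = shared_experts(identity, bsz_tensor)   # GPU runs concurrently with the CPU experts
y = experts.sync_for_one_decode(graph_idx)        # wait, then fetch the CPU result on the GPU
y += y_shared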
@@ -859,6 +899,8 @@ class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE):
            y += y_
        return y


    @torch.no_grad()
    def moe_kexperts(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
        outs = self.experts(x, topk_ids, topk_weight)
@@ -1013,4 +1055,178 @@ class KMistralSparseMoEBlock(BaseInjectedModule, MixtralSparseMoeBlock):
            # the `top_x` tensor here.
            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states_cpu.dtype))

        return final_hidden_states

class KDeepseekV3MoEV2(BaseInjectedModule, DeepseekV3MoE):
    def forward(self, hidden_states, bsz_tensor, cuda_graph_idx=0):
        identity = hidden_states
        orig_shape = hidden_states.shape
        sequence_length = orig_shape[1]
        topk_idx, topk_weight = self.gate(hidden_states)
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])

        # only for generate phase
        if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing():  # TODO: this branch causes a jit bug
            self.experts.generate_experts.submit_for_one_decode(hidden_states, topk_idx, topk_weight, bsz_tensor, cuda_graph_idx)
            if self.config.n_shared_experts is not None:
                y_ = self.shared_experts(identity, bsz_tensor).squeeze(0)
            y = self.experts.generate_experts.sync_for_one_decode(cuda_graph_idx).unsqueeze(0)
            y += y_
            y.resize_(*orig_shape)
            return y

        if self.config.n_shared_experts is not None:
            y_ = self.shared_experts(identity, bsz_tensor).squeeze(0)

        if isinstance(self.experts, KExpertsBase):
            y = self.moe_on_cpuinfer(hidden_states, topk_idx, topk_weight, bsz_tensor, cuda_graph_idx).view(*orig_shape).to(device=hidden_states.device)
        elif hidden_states.size(0) > 10:
            # TODO: there may be bugs here
            y = (
                self.moe_infer(hidden_states, topk_idx, topk_weight)
                .view(*orig_shape)
                .to(device=hidden_states.device)
            )
        else:
            # TODO: there may be bugs here
            y = (
                self.moe_infer_simple(hidden_states, topk_idx, topk_weight)
                .view(*orig_shape)
                .to(device=hidden_states.device)
            )
        if self.config.n_shared_experts is not None:
            y += y_
        return y
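`torch.cuda.is_current_stream_capturing()` is how both `KExpertsCPU.forward` and the branch above detect that they are running inside a CUDA-graph capture, where only fixed, pre-allocated buffers may be touched. A minimal, self-contained sketch of capture and replay (illustrative; real use typically warms the stream up first):

import torch

static_in = torch.zeros(8, device="cuda")
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):            # inside, is_current_stream_capturing() is True
    static_out = static_in * 2.0
static_in.fill_(3.0)                     # update the static input in place
graph.replay()                           # re-runs the captured kernels
print(static_out[0].item())              # 6.0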

    @torch.no_grad()
    def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor, bsz_tensor, cuda_graph_idx=0) -> torch.Tensor:
        outs = self.experts(x, topk_ids, topk_weight, bsz_tensor, cuda_graph_idx)
        return outs

    @torch.no_grad()
    # TODO: there may be bugs here
    def moe_infer_simple(
        self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor
    ) -> torch.Tensor:
        """
        x: [num_tokens, hidden_size]
        topk_ids, topk_weight: [num_tokens, num_selected_experts]
        """
        outs = torch.zeros_like(x)
        for token_idx in range(topk_ids.size(0)):
            for expert_idx in range(topk_ids.size(1)):
                expert = self.experts[topk_ids[token_idx, expert_idx]]
                outs[token_idx] += (
                    expert.forward(x[token_idx]) * topk_weight[token_idx, expert_idx]
                )
        return outs

    @torch.no_grad()
    # TODO: there may be bugs here
    def moe_infer(self, x, topk_ids, topk_weight):
        # count how many tokens each expert receives
        cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
        cnts.scatter_(1, topk_ids, 1)
        tokens_per_expert = cnts.sum(dim=0)
        # sorting the flattened (token, slot) indices groups tokens by expert id
        idxs = topk_ids.view(-1).argsort()
        sorted_tokens = x[idxs // topk_ids.shape[1]]
        tokens_per_expert = tokens_per_expert.cpu().numpy()

        outputs = []
        start_idx = 0
        for i, num_tokens in enumerate(tokens_per_expert):
            end_idx = start_idx + num_tokens
            if num_tokens == 0:
                continue
            expert = self.experts[i + self.ep_rank * self.experts_per_rank]
            tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
            expert_out = expert.forward(tokens_for_this_expert)
            outputs.append(expert_out)
            start_idx = end_idx

        outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)

        # scatter the per-expert outputs back to their original slots and reduce per token
        new_x = torch.empty_like(outs)
        new_x[idxs] = outs
        final_out = (
            new_x.view(*topk_ids.shape, -1)
            .type(topk_weight.dtype)
            .mul_(topk_weight.unsqueeze(dim=-1))
            .sum(dim=1)
            .type(new_x.dtype)
        )
        return final_out
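The heart of `moe_infer` is the argsort trick: flattening `topk_ids` and sorting the slot indices groups all tokens destined for the same expert into one contiguous batch. A toy illustration with two tokens and two experts per token:

import torch

topk_ids = torch.tensor([[0, 2], [1, 2]])    # token 0 -> experts {0, 2}, token 1 -> {1, 2}
idxs = topk_ids.view(-1).argsort()           # slot order grouped by expert id
print(idxs.tolist())                         # [0, 2, 1, 3]: slots for experts 0, 1, 2, 2
print((idxs // topk_ids.shape[1]).tolist())  # [0, 1, 0, 1]: the token feeding each slot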

class KTransformersExpertsV2(BaseInjectedModule, KExpertsBase):
    def __init__(self,
                 key: str,
                 gguf_loader: GGUFLoader,
                 config: PretrainedConfig,
                 orig_module: nn.Module,
                 # device: str = "cuda",
                 prefill_device: str = "cuda",
                 prefill_op: str | None = "KExpertsTorch",
                 generate_device: str = "cpu",
                 generate_op: str | None = "KExpertsCPU",
                 **kwargs):
        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
        KExpertsBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
        if generate_op is not None:
            self.generate_experts = EXPERTS_MAP[generate_op](key, gguf_loader, config, len(orig_module), device=generate_device, **kwargs)
        else:
            self.generate_experts = None
        if prefill_op is not None:
            self.prefill_experts = EXPERTS_MAP[prefill_op](key, gguf_loader, config, len(orig_module), device=prefill_device, **kwargs)
        else:
            self.prefill_experts = None
        self.gpu_mlp_type = prefill_op
        self.cpu_mlp_type = generate_op
        self.mode = InferenceState.UNLOAD

    def load(self, w: dict = None, mode: InferenceState = None, warmup: bool = True):
        # TODO: support w as input
        if not mode:
            mode = InferenceState.GENERATE
        if mode == InferenceState.GENERATE:
            self.prefill_experts.unload()
            self.generate_experts.load(w, warmup=warmup)
            self.device = self.generate_experts.device
            self.mode = mode
        elif mode == InferenceState.PREFILL:
            self.generate_experts.unload()
            self.prefill_experts.load(w, warmup=warmup)
            self.device = self.prefill_experts.device
            self.mode = mode
        elif mode == InferenceState.UNLOAD:
            self.unload()
            self.mode = mode
            self.device = self.generate_experts.device
        else:
            raise ValueError("mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD")

    def unload(self):
        if self.generate_experts is not None:
            self.generate_experts.unload()
        if self.prefill_experts is not None:
            self.prefill_experts.unload()
        self.device = self.generate_experts.device

    def forward(self, input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx=0):
        if self.mode == InferenceState.GENERATE:
            assert self.generate_experts is not None, "generate_experts is None"
            return self.generate_experts.forward(input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx)
        elif self.mode == InferenceState.PREFILL:
            assert self.prefill_experts is not None, "prefill_experts is None"
            return self.prefill_experts.forward(input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx)
        else:
            raise ValueError("load or set_inference_mode before forward")

    def set_inference_mode(self, mode: InferenceState):
        if mode == InferenceState.GENERATE:
            self.load(mode=InferenceState.GENERATE, warmup=False)
        elif mode == InferenceState.PREFILL:
            self.load(mode=InferenceState.PREFILL, warmup=False)
        elif mode == InferenceState.UNLOAD:
            self.unload()
        else:
            raise ValueError("mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD")
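A minimal usage sketch of the V2 wrapper's mode switching (the `experts` instance and routing tensors are illustrative, not from this diff): prefill runs on the GPU backend, then the module is flipped to the CPU backend for decoding.

experts.set_inference_mode(InferenceState.PREFILL)    # GPU experts for the prompt
out = experts(prompt_hidden, topk_ids, topk_weight, bsz_tensor)

experts.set_inference_mode(InferenceState.GENERATE)   # CPU experts for decode steps
out = experts(step_hidden, topk_ids, topk_weight, bsz_tensor)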