mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-09 22:05:30 +00:00
support smt and glm4
This commit is contained in:
parent
613f0b7c37
commit
590fcb41cd
5 changed files with 95 additions and 7 deletions
|
@ -1388,6 +1388,78 @@ class KSmallthinkerExperts(BaseInjectedModule, KExpertsBase):
|
|||
else:
|
||||
raise ValueError("mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD")
|
||||
|
||||
class KGlm4Experts(BaseInjectedModule, KExpertsBase):
    """Injected experts module that dispatches to a generate or prefill backend.

    The generate backend is looked up in ``EXPERTS_MAP`` by ``generate_op`` and
    built on ``generate_device``. This class never constructs a prefill
    backend: ``prefill_experts`` is always ``None``, so only the GENERATE mode
    can actually be loaded and run.
    """

    def __init__(self,
                 key: str,
                 gguf_loader: GGUFLoader,
                 config: PretrainedConfig,
                 orig_module: nn.Module,
                 # device: str = "cuda",
                 prefill_device:str = "cuda",
                 prefill_op: str | None = "KExpertsTorch",
                 generate_device: str = "cpu",
                 generate_op: str | None = "KExpertsCPU",
                 **kwargs):
        """Initialise both base classes and build the generate backend.

        Args:
            key: Forwarded to the base classes and to the backend constructor.
            gguf_loader: Forwarded to the base classes and to the backend.
            config: Forwarded to the base classes and to the backend.
            orig_module: Module being replaced; ``len(orig_module)`` is passed
                to the backend constructor.
            prefill_device: Accepted for interface compatibility; unused here.
            prefill_op: Only recorded in ``gpu_mlp_type``; no backend is built.
            generate_device: Device passed to the bases and generate backend.
            generate_op: Key into ``EXPERTS_MAP``; ``None`` disables the
                generate backend.
            **kwargs: Forwarded to the bases and to the backend constructor.
        """
        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
        KExpertsBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
        if generate_op is not None:
            self.generate_experts = EXPERTS_MAP[generate_op](key, gguf_loader, config, len(orig_module), device=generate_device, **kwargs)
        else:
            self.generate_experts = None
        # Always define the attribute. Previously it was only assigned when
        # prefill_op was not None, so unload() raised AttributeError when the
        # module was built with prefill_op=None.
        self.prefill_experts = None
        self.gpu_mlp_type = prefill_op
        self.cpu_mlp_type = generate_op
        self.mode = InferenceState.UNLOAD

    def load(self, w: dict | None = None, mode: InferenceState | None = None, warmup: bool = True):
        """Load the backend for ``mode`` and record the new mode and device.

        Args:
            w: Passed through to the backend's ``load``.
            mode: Target state; ``None`` means ``InferenceState.GENERATE``.
            warmup: Passed through to the backend's ``load``.

        Raises:
            ValueError: If ``mode`` is not GENERATE, PREFILL or UNLOAD, or if
                the backend required by ``mode`` was not constructed.
        """
        # TODO support w as input
        # Compare against None explicitly so a falsy enum member can never be
        # mistaken for "no mode given".
        if mode is None:
            mode = InferenceState.GENERATE
        if mode == InferenceState.GENERATE:
            if self.generate_experts is None:
                raise ValueError("generate_experts is None")
            # self.prefill_experts.unload()
            self.generate_experts.load(w, warmup=warmup)
            self.device = self.generate_experts.device
            self.mode = mode
        elif mode == InferenceState.PREFILL:
            # Check before touching the generate backend so a failed switch
            # does not leave the module with its generate weights unloaded.
            if self.prefill_experts is None:
                raise ValueError("prefill_experts is None")
            if self.generate_experts is not None:
                self.generate_experts.unload()
            self.prefill_experts.load(w, warmup=warmup)
            self.device = self.prefill_experts.device
            self.mode = mode
        elif mode == InferenceState.UNLOAD:
            # unload() also refreshes self.device when a generate backend exists.
            self.unload()
            self.mode = mode
        else:
            raise ValueError("mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD")

    def unload(self):
        """Unload every backend that exists; does not change ``self.mode``."""
        if self.generate_experts is not None:
            self.generate_experts.unload()
            # Only read the device from a backend that actually exists.
            self.device = self.generate_experts.device
        if self.prefill_experts is not None:
            self.prefill_experts.unload()

    def forward(self, input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx=0):
        """Forward all arguments to the backend selected by ``self.mode``.

        Raises:
            ValueError: If the module is in neither GENERATE nor PREFILL mode.
        """
        if self.mode == InferenceState.GENERATE:
            assert self.generate_experts is not None, "generate_experts is None"
            return self.generate_experts.forward(input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx)
        elif self.mode == InferenceState.PREFILL:
            assert self.prefill_experts is not None, "prefill_experts is None"
            return self.prefill_experts.forward(input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx)
        else:
            raise ValueError("load or set_inference_mode before forward")

    def set_inference_mode(self, mode: InferenceState):
        """Switch to ``mode`` without warmup.

        Raises:
            ValueError: If ``mode`` is not GENERATE, PREFILL or UNLOAD.
        """
        if mode == InferenceState.GENERATE:
            self.load(mode=InferenceState.GENERATE, warmup=False)
        elif mode == InferenceState.PREFILL:
            self.load(mode=InferenceState.PREFILL, warmup=False)
        elif mode == InferenceState.UNLOAD:
            # NOTE(review): unlike load(mode=UNLOAD), this path leaves
            # self.mode unchanged, so forward() keeps dispatching to the
            # previous backend. Confirm whether callers rely on that.
            self.unload()
        else:
            raise ValueError("mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD")
|
||||
|
||||
|
||||
class KQwen2MoeSparseMoeBlockV2(BaseInjectedModule, Qwen2MoeSparseMoeBlock):
|
||||
def forward(self, hidden_states, bsz_tensor, cuda_graph_idx=0):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue