mirror of https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-05 20:19:51 +00:00
fix load bug
commit 27990dc6fb
parent 74bb7fdcf6

3 changed files with 4 additions and 2 deletions
@@ -25,7 +25,6 @@ sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext
 sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build", "Debug"))
 import cpuinfer_ext
 from cpuinfer_ext.moe import MOEConfig, MOE
-from cpuinfer_ext.moe import AMX_MOEConfig, AMXBF16_MOE, AMXInt8_MOE
 import ctypes
 from ktransformers.util.custom_gguf import GGMLQuantizationType, GGUFLoader
 from ktransformers.util.utils import InferenceState
@@ -186,6 +185,7 @@ class KExpertsCPU(KExpertsBase):
             )
             self.moe = MOE(moe_config)
         elif self.backend == "AMXBF16":
+            from cpuinfer_ext.moe import AMX_MOEConfig, AMXBF16_MOE
             assert self.gate_type == GGMLQuantizationType.BF16
             assert self.up_type == GGMLQuantizationType.BF16
             assert self.down_type == GGMLQuantizationType.BF16
@@ -203,6 +203,7 @@ class KExpertsCPU(KExpertsBase):
             self.cpu_infer.submit(self.moe.load_weights())
             self.cpu_infer.sync()
         elif self.backend == "AMXInt8":
+            from cpuinfer_ext.moe import AMX_MOEConfig, AMXInt8_MOE
             assert self.gate_type == GGMLQuantizationType.BF16
             assert self.up_type == GGMLQuantizationType.BF16
             assert self.down_type == GGMLQuantizationType.BF16
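
Taken together, the three hunks above are the "load bug" fix: the top-level import of the AMX classes is removed and redone lazily inside the AMXBF16/AMXInt8 branches. With the old module-level import, merely loading this file raised ImportError on builds of cpuinfer_ext compiled without AMX kernels, even when only the default backend was requested. A minimal sketch of the deferred-import pattern, assuming a hypothetical make_moe() dispatcher (backend names and call signatures here are illustrative, not the ktransformers API):

    def make_moe(backend: str, moe_config):
        if backend == "llamafile":
            # Default path: uses only symbols every build of the extension provides.
            from cpuinfer_ext.moe import MOE
            return MOE(moe_config)
        elif backend == "AMXBF16":
            # Deferred import: on a build without AMX kernels this raises
            # ImportError here, when the AMX backend is actually selected,
            # rather than at module import time.
            # (AMX_MOEConfig comes along because the real code builds an
            # AMX-specific config before constructing the MOE.)
            from cpuinfer_ext.moe import AMX_MOEConfig, AMXBF16_MOE
            return AMXBF16_MOE(moe_config)
        raise ValueError(f"unsupported backend: {backend!r}")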
@@ -85,7 +85,7 @@ class ModelRunner:
         elif isinstance(self.model, KQwen2MoeForCausalLM) or isinstance(self.model, KQwen3MoeForCausalLM):
             self.model.flash_infer_attn_plan(batch, self.bsz_tensor_buf, self.num_tokens_tensor_buf,
                             num_q_heads=self.model.config.num_attention_heads, num_kv_heads=self.model.config.num_key_value_heads,
-                            head_dim=self.model.config.head_dim if hasattr(self.model.config, 'head_num') else self.model.config.hidden_size // self.model.config.num_attention_heads,
+                            head_dim=self.model.config.head_dim if hasattr(self.model.config, 'head_dim') else self.model.config.hidden_size // self.model.config.num_attention_heads,
                             page_size=self.model.cache.page_size, causal=True,
                             q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16, cuda_graph_idx=cuda_graph_idx)
         else:
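
This hunk fixes an attribute-name mismatch in the guard: the expression reads config.head_dim but the old code tested hasattr(config, 'head_num'). A config that defines head_dim (but no head_num attribute) therefore silently fell back to hidden_size // num_attention_heads, which is wrong for models whose head_dim is decoupled from the hidden size; conversely, a config with head_num but no head_dim would have raised AttributeError. A self-contained repro with illustrative (assumed) config values:

    class Cfg:                  # hypothetical stand-in for a model config
        hidden_size = 2048
        num_attention_heads = 32
        head_dim = 128          # explicit, and != 2048 // 32

    cfg = Cfg()
    # Old guard: tests the wrong attribute, so the explicit head_dim is ignored.
    before = cfg.head_dim if hasattr(cfg, 'head_num') else cfg.hidden_size // cfg.num_attention_heads
    # Fixed guard: tests the same attribute it reads.
    after = cfg.head_dim if hasattr(cfg, 'head_dim') else cfg.hidden_size // cfg.num_attention_heads
    print(before, after)        # 64 128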
@@ -7,3 +7,4 @@ cpufeature; sys_platform == 'win32' or sys_platform == 'Windows'
 protobuf
 tiktoken
 blobfile
+triton==3.3