add XPU support for qwen3moe local chat

rnwang04 2025-05-21 18:33:41 +08:00
parent 25893366b6
commit adc0906967
9 changed files with 223 additions and 25 deletions


@@ -28,7 +28,7 @@ from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM
 from ktransformers.models.modeling_deepseek_v3 import DeepseekV3ForCausalLM
 from ktransformers.models.modeling_llama import LlamaForCausalLM
 from ktransformers.models.modeling_mixtral import MixtralForCausalLM
-from ktransformers.util.utils import prefill_and_generate, get_compute_capability
+from ktransformers.util.utils import prefill_and_generate, get_compute_capability, xpu_fp16_model
 from ktransformers.server.config.config import Config
 from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
 from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor
@@ -78,7 +78,7 @@ def local_chat(
     if mode == 'long_context':
         assert config.architectures[0] == "LlamaForCausalLM", "only LlamaForCausalLM support long_context mode"
         torch.set_default_dtype(torch.float16)
-    elif torch.xpu.is_available() and config.architectures[0] == "DeepseekV3ForCausalLM":
+    elif xpu_fp16_model(config):
         torch.set_default_dtype(torch.float16)
     else:
         torch.set_default_dtype(config.torch_dtype)
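
Note: the xpu_fp16_model() helper imported in the first hunk replaces the hard-coded DeepseekV3-on-XPU check above. Its definition lives in ktransformers/util/utils.py, one of the other files changed in this commit and not shown here, so the following is only a minimal sketch of what it plausibly does, assuming it generalizes the old condition to the architectures that should default to fp16 on Intel XPU (the exact architecture list is an assumption inferred from the commit title).

import torch

def xpu_fp16_model(config) -> bool:
    # Hypothetical sketch: decide whether this model should default to
    # torch.float16 when running on an Intel XPU.
    if not (hasattr(torch, "xpu") and torch.xpu.is_available()):
        return False
    # DeepseekV3 was the previously hard-coded case; Qwen3Moe is assumed to be
    # the newly added one, given the commit title.
    return config.architectures[0] in ("DeepseekV3ForCausalLM", "Qwen3MoeForCausalLM")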
@@ -94,11 +94,16 @@ def local_chat(
                 config._attn_implementation = "eager"
             if "Mixtral" in config.architectures[0]:
                 config._attn_implementation = "flash_attention_2"
+            if torch.xpu.is_available():
+                config._attn_implementation = "eager"
             model = custom_models[config.architectures[0]](config)
         else:
+            if torch.xpu.is_available():
+                attn_implementation = "eager"
+            else:
+                attn_implementation = "flash_attention_2"
             model = AutoModelForCausalLM.from_config(
-                config, trust_remote_code=True, attn_implementation="flash_attention_2"
+                config, trust_remote_code=True, attn_implementation=attn_implementation
             )
     if optimize_config_path is None:
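
Note: taken together, the last hunk means the attention backend is now chosen per device instead of being hard-coded to flash_attention_2. A standalone sketch of that selection, under the assumption that FlashAttention kernels are unavailable on Intel XPU (the hasattr guard is an added safety check for PyTorch builds without the xpu module, not part of the diff):

import torch

def pick_attn_implementation() -> str:
    # FlashAttention kernels are CUDA-only, so fall back to eager attention on Intel XPU.
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return "eager"
    return "flash_attention_2"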