Mirror of https://github.com/kvcache-ai/ktransformers.git (synced 2025-09-10 23:34:35 +00:00)
add XPU support for qwen3moe local chat
This commit is contained in:
commit adc0906967 (parent 25893366b6)
9 changed files with 223 additions and 25 deletions (only one file's hunks are shown below)
@@ -28,7 +28,7 @@ from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM
 from ktransformers.models.modeling_deepseek_v3 import DeepseekV3ForCausalLM
 from ktransformers.models.modeling_llama import LlamaForCausalLM
 from ktransformers.models.modeling_mixtral import MixtralForCausalLM
-from ktransformers.util.utils import prefill_and_generate, get_compute_capability
+from ktransformers.util.utils import prefill_and_generate, get_compute_capability, xpu_fp16_model
 from ktransformers.server.config.config import Config
 from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
 from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor
@@ -78,7 +78,7 @@ def local_chat(
     if mode == 'long_context':
         assert config.architectures[0] == "LlamaForCausalLM", "only LlamaForCausalLM support long_context mode"
         torch.set_default_dtype(torch.float16)
-    elif torch.xpu.is_available() and config.architectures[0] == "DeepseekV3ForCausalLM":
+    elif xpu_fp16_model(config):
         torch.set_default_dtype(torch.float16)
     else:
         torch.set_default_dtype(config.torch_dtype)
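The `xpu_fp16_model` helper that replaces the inline check above is imported from `ktransformers.util.utils`, but its body is not part of this diff. A minimal, non-authoritative sketch of what such a check could look like, assuming it simply folds the old condition (XPU present and a DeepseekV3 model) together with the Qwen3-MoE case named in the commit title; the architecture strings below are assumptions, not taken from this hunk:

```python
import torch

def xpu_fp16_model(config) -> bool:
    """Sketch only: decide whether to default to float16 when running on an Intel XPU.

    Assumption: the real helper in ktransformers.util.utils may check more
    architectures or other config fields; this mirrors the old inline test and
    adds the Qwen3-MoE architecture this commit targets.
    """
    if not torch.xpu.is_available():
        return False
    return config.architectures[0] in (
        "DeepseekV3ForCausalLM",
        "Qwen3MoeForCausalLM",  # assumed architecture string for Qwen3 MoE
    )
```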
@@ -94,11 +94,16 @@ def local_chat(
                 config._attn_implementation = "eager"
             if "Mixtral" in config.architectures[0]:
                 config._attn_implementation = "flash_attention_2"
 
+            if torch.xpu.is_available():
+                config._attn_implementation = "eager"
             model = custom_models[config.architectures[0]](config)
         else:
+            if torch.xpu.is_available():
+                attn_implementation = "eager"
+            else:
+                attn_implementation = "flash_attention_2"
             model = AutoModelForCausalLM.from_config(
-                config, trust_remote_code=True, attn_implementation="flash_attention_2"
+                config, trust_remote_code=True, attn_implementation=attn_implementation
             )
 
     if optimize_config_path is None:
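For readers following the hunk above, this is roughly how the model-construction branch reads once the patch is applied: on XPU the eager attention path is forced, otherwise flash_attention_2 is kept. This is a reconstruction from the diff context, not a verbatim excerpt of the file, and the enclosing `if config.architectures[0] in custom_models:` line is inferred from the context lines:

```python
# Reconstruction of the patched branch in local_chat(); surrounding code elided.
if config.architectures[0] in custom_models:
    if torch.xpu.is_available():
        config._attn_implementation = "eager"  # XPU path uses eager attention
    model = custom_models[config.architectures[0]](config)
else:
    if torch.xpu.is_available():
        attn_implementation = "eager"
    else:
        attn_implementation = "flash_attention_2"
    model = AutoModelForCausalLM.from_config(
        config, trust_remote_code=True, attn_implementation=attn_implementation
    )
```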